From 3fff90eb437dce30fd83012f4c0e24f3fca041b2 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 14 Jan 2022 17:47:27 +0100 Subject: [WIP] Set dev version and add minigallery to quick start guide (#334) * change version and add minigallery in quickstart guide * remove ot.gpu from documentation because it is deprecated and bacckends should be used * start 0.8.2dev and description in releases.md * typo for gallery sinkhorn2 * test better doc update for files in .githib folder --- .circleci/config.yml | 8 +- RELEASES.md | 5 + docs/source/all.rst | 1 - docs/source/auto_examples/images/bak.png | Bin 304669 -> 0 bytes docs/source/auto_examples/images/sinkhorn.png | Bin 37204 -> 0 bytes docs/source/conf.py | 1 - docs/source/quickstart.rst | 219 ++++++++++---------------- ot/__init__.py | 2 +- 8 files changed, 96 insertions(+), 140 deletions(-) delete mode 100644 docs/source/auto_examples/images/bak.png delete mode 100644 docs/source/auto_examples/images/sinkhorn.png diff --git a/.circleci/config.yml b/.circleci/config.yml index f5cb756..5427979 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -106,10 +106,10 @@ jobs: echo "Deploying dev docs for ${CIRCLE_BRANCH}."; cd master cp -a /tmp/build/html/* .; - cp -a /tmp/build/html/.github .github; + cp -a /tmp/build/html/.github/* .github/; touch .nojekyll; git add -A; - git add -f .github/*.html ; + git add -f .github/* ; git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM})."; git push origin master; else @@ -146,10 +146,10 @@ jobs: git clean -xdf echo "Deploying dev docs for ${CIRCLE_BRANCH}."; cp -a /tmp/build/html/* .; - cp -a /tmp/build/html/.github .github; + cp -a /tmp/build/html/.github/* .github/; touch .nojekyll; git add -A; - git add -f .github/*.html ; + git add -f .github/* ; git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM})."; git push origin master; diff --git a/RELEASES.md b/RELEASES.md index 00af0fb..9b92d97 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,10 @@ # Releases +## 0.8.2dev Development + +#### New features + +- Better list of related examples in quick start guide with `minigallery` (PR #334) ## 0.8.1.0 *December 2021* diff --git a/docs/source/all.rst b/docs/source/all.rst index 6a07599..7f85a91 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -20,7 +20,6 @@ API and modules gromov optim da - gpu dr utils datasets diff --git a/docs/source/auto_examples/images/bak.png b/docs/source/auto_examples/images/bak.png deleted file mode 100644 index 25e7e8e..0000000 Binary files a/docs/source/auto_examples/images/bak.png and /dev/null differ diff --git a/docs/source/auto_examples/images/sinkhorn.png b/docs/source/auto_examples/images/sinkhorn.png deleted file mode 100644 index e003e13..0000000 Binary files a/docs/source/auto_examples/images/sinkhorn.png and /dev/null differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 849e97c..163851f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -74,7 +74,6 @@ extensions = [ autosummary_generate = True - napoleon_numpy_docstring = True # Add any paths that contain templates here, relative to this directory. diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 232df7b..e74b019 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -207,13 +207,12 @@ The method implemented for solving the OT problem is the network simplex. It is implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the solver is quite efficient and uses sparsity of the solution. -.. hint:: - Examples of use for :any:`ot.emd` are available in : - - :any:`auto_examples/plot_OT_2D_samples` - - :any:`auto_examples/plot_OT_1D` - - :any:`auto_examples/plot_OT_L1_vs_L2` +.. minigallery:: ot.emd + :add-heading: Examples of use for :any:`ot.emd` + :heading-level: " + Computing Wasserstein distance @@ -255,11 +254,9 @@ the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2` when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean distance. -.. hint:: - - An example of use for :any:`ot.emd2` is available in : - - - :any:`auto_examples/plot_compute_emd` +.. minigallery:: ot.emd2 + :add-heading: Examples of use for :any:`ot.emd2` + :heading-level: " Special cases @@ -416,17 +413,18 @@ of stochastic solvers for entropic regularized OT [18]_ [19]_. Those pure Pytho implementations are not optimized for speed but provide a robust implementation of algorithms in [18]_ [19]_. -.. hint:: - Examples of use for :any:`ot.sinkhorn` are available in : - - :any:`auto_examples/plot_OT_2D_samples` - - :any:`auto_examples/plot_OT_1D` - - :any:`auto_examples/plot_OT_1D_smooth` - - :any:`auto_examples/plot_stochastic` +.. minigallery:: ot.sinkhorn + :add-heading: Examples of use for :any:`ot.sinkhorn` + :heading-level: " +.. minigallery:: ot.sinkhorn2 + :add-heading: Examples of use for :any:`ot.sinkhorn2` + :heading-level: " -Other regularization -^^^^^^^^^^^^^^^^^^^^ + +Other regularizations +^^^^^^^^^^^^^^^^^^^^^ While entropic OT is the most common and favored in practice, there exists other kinds of regularizations. We provide in POT two specific solvers for other @@ -451,12 +449,9 @@ functions :any:`ot.smooth.smooth_ot_dual` or :any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to choose the quadratic regularization. -.. hint:: - Examples of quadratic regularization are available in : - - - :any:`auto_examples/plot_OT_1D_smooth` - - :any:`auto_examples/plot_optim_OTreg` - +.. minigallery:: ot.smooth.smooth_ot_dual ot.smooth.smooth_ot_semi_dual ot.optim.cg + :add-heading: Examples of use of quadratic regularization + :heading-level: " Group Lasso regularization @@ -480,11 +475,9 @@ be solved using an efficient majoration minimization approach with convex group lasso and we provide a solver using generalized conditional gradient algorithm [7]_ in function :any:`ot.da.sinkhorn_l1l2_gl`. -.. hint:: - Examples of group Lasso regularization are available in: - - - :any:`auto_examples/domain-adaptation/plot_otda_classes` - - :any:`auto_examples/domain-adaptation/plot_otda_d2` +.. minigallery:: ot.da.SinkhornLpl1Transport ot.da.SinkhornL1l2Transport ot.da.sinkhorn_l1l2_gl ot.da.sinkhorn_lpl1_mm + :add-heading: Examples of group Lasso regularization + :heading-level: " Generic solvers @@ -520,10 +513,9 @@ generalized conditional gradient [7]_ implemented in :any:`ot.optim.gcg` that does not linearize the entropic term but relies on :any:`ot.sinkhorn` for its iterations. -.. hint:: - An example of generic solvers are available in : - - - :any:`auto_examples/plot_optim_OTreg` +.. minigallery:: ot.optim.cg ot.optim.gcg + :add-heading: Examples of the generic solvers + :heading-level: " Wasserstein Barycenters @@ -581,19 +573,15 @@ the matrix vector production in the Bregman projections by convolution operators. We provide an implementation of this algorithm in function :any:`ot.bregman.convolutional_barycenter2d`. -.. hint:: - Examples of Wasserstein (:meth:`ot.lp.barycenter`) and regularized Wasserstein - barycenter (:any:`ot.bregman.barycenter`) computation are available in : - - :any:`auto_examples/barycenters/plot_barycenter_1D` - - :any:`auto_examples/barycenters/plot_barycenter_lp_vs_entropic` - An example of convolutional barycenter - (:any:`ot.bregman.convolutional_barycenter2d`) computation - for 2D images is available - in : +.. minigallery:: ot.lp.barycenter ot.bregman.barycenter ot.barycenter + :add-heading: Examples of Wasserstein and regularized Wasserstein barycenters + :heading-level: " - - :any:`auto_examples/barycenters/plot_convolutional_barycenter` +.. minigallery:: ot.bregman.convolutional_barycenter2d + :add-heading: An example of convolutional barycenter (:any:`ot.bregman.convolutional_barycenter2d`) computation + :heading-level: " @@ -613,13 +601,9 @@ We provide a solver based on [20]_ in return a locally optimal support :math:`\{x_i\}` for uniform or given weights :math:`a`. - .. hint:: - - An example of the free support barycenter estimation is available - in : - - - :any:`auto_examples/barycenters/plot_free_support_barycenter` - +.. minigallery:: ot.lp.free_support_barycenter + :add-heading: Examples of free support barycenter estimation + :heading-level: " @@ -656,12 +640,10 @@ method proposed in [8]_ that estimates a continuous mapping approximating the barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping. - .. hint:: - - An example of the linear Monge mapping estimation is available - in : +.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_linear ot.da.OT_mapping_linear + :add-heading: Examples of Monge mapping estimation + :heading-level: " - - :any:`auto_examples/domain-adaptation/plot_otda_linear_mapping` Domain adaptation classes ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -704,14 +686,11 @@ A list of the provided implementation is given in the following note. [14]_ * :any:`ot.da.MappingTransport`: Nonlinear mapping estimation [8]_ -.. hint:: - Examples of the use of OTDA classes are available in: +.. minigallery:: ot.da.SinkhornTransport ot.da.LinearTransport + :add-heading: Examples of the use of OTDA classes + :heading-level: " - - :any:`auto_examples/domain-adaptation/plot_otda_color_images` - - :any:`auto_examples/domain-adaptation/plot_otda_mapping` - - :any:`auto_examples/domain-adaptation/plot_otda_mapping_colors_images` - - :any:`auto_examples/domain-adaptation/plot_otda_semi_supervised` Other applications ------------------ @@ -746,11 +725,10 @@ respectively. Note that we also provide the Fisher discriminant estimator in :code:`autograd`, :any:`ot.dr` is not imported by default. If you want to use it you have to specifically import it with :code:`import ot.dr` . -.. hint:: +.. minigallery:: ot.dr.wda + :add-heading: Examples of the use of WDA + :heading-level: " - An example of the use of WDA is available in : - - - :any:`auto_examples/others/plot_WDA` Unbalanced optimal transport @@ -787,11 +765,9 @@ linear term. the log stabilized version of the algorithm [10]_. -.. hint:: - - Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in : - - - :any:`auto_examples/unbalanced-partial/plot_UOT_1D` +.. minigallery:: ot.sinkhorn_unbalanced ot.sinkhorn_unbalanced2 ot.unbalanced.sinkhorn_unbalanced + :add-heading: Examples of Unbalanced OT + :heading-level: " Unbalanced Barycenters @@ -819,11 +795,10 @@ implemented the main function :any:`ot.barycenter_unbalanced`. the log stabilized version of the algorithm [10]_. -.. hint:: +.. minigallery:: ot.barycenter_unbalanced ot.unbalanced.barycenter_unbalanced + :add-heading: Examples of Unbalanced OT barycenters + :heading-level: " - Examples of the use of :any:`ot.barycenter_unbalanced` are available in : - - - :any:`auto_examples/unbalanced-partial/plot_UOT_barycenter_1D` Partial optimal transport @@ -865,11 +840,10 @@ is computed in :any:`ot.partial.partial_gromov_wasserstein` and in regularization of the problem. -.. hint:: - - Examples of the use of :any:`ot.partial` are available in: +.. minigallery:: ot.partial.partial_wasserstein ot.partial.partial_gromov_wasserstein + :add-heading: Examples of Partial OT + :heading-level: " - - :any:`auto_examples/unbalanced-partial/plot_partial_wass_and_gromov` @@ -898,6 +872,12 @@ There also exists an entropic regularized variant of GW that has been proposed i [12]_ and we provide an implementation of their algorithm in :any:`ot.gromov.entropic_gromov_wasserstein`. + +.. minigallery:: ot.gromov.gromov_wasserstein ot.gromov.entropic_gromov_wasserstein ot.gromov.fused_gromov_wasserstein ot.gromov.gromov_wasserstein2 + :add-heading: Examples of computation of GW, regularized G and FGW + :heading-level: " + + Note that similarly to Wasserstein distance GW allows for the definition of GW barycenters that can be expressed as @@ -919,59 +899,15 @@ graphs for instance and also provide computable barycenters. The implementations of FGW and FGW barycenter is provided in functions :any:`ot.gromov.fused_gromov_wasserstein` and :any:`ot.gromov.fgw_barycenters`. -.. hint:: - - Examples of computation of GW, regularized G and FGW are available in: - - :any:`auto_examples/gromov/plot_gromov` - - :any:`auto_examples/gromov/plot_fgw` +.. minigallery:: ot.gromov.gromov_barycenters ot.gromov.fgw_barycenters + :add-heading: Examples of GW, regularized G and FGW barycenters + :heading-level: " - Examples of GW, regularized GW and FGW barycenters are available in: - - :any:`auto_examples/gromov/plot_gromov_barycenter` - - :any:`auto_examples/gromov/plot_barycenter_fgw` - -GPU acceleration -^^^^^^^^^^^^^^^^ - -.. warning:: - - The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and - should not be used. The GPU implementation (in Pytorch for instance) can be - used with the novel backends using the compatible functions from POT. - - -We provide several implementation of our OT solvers in :any:`ot.gpu`. Those -implementations use the :code:`cupy` toolbox that obviously need to be installed. - - -.. note:: - - Several implementations of POT functions (mainly those relying on linear - algebra) have been implemented in :any:`ot.gpu`. Here is a short list on the - main entries: - - - :meth:`ot.gpu.dist`: computation of distance matrix - - :meth:`ot.gpu.sinkhorn`: computation of sinkhorn - - :meth:`ot.gpu.sinkhorn_lpl1_mm`: computation of sinkhorn + group lasso - -Note that while the :any:`ot.gpu` module has been designed to be compatible with -POT, calling its function with :any:`numpy` arrays will incur a large overhead due to -the memory copy of the array on GPU prior to computation and conversion of the -array after computation. To avoid this overhead, we provide functions -:meth:`ot.gpu.to_gpu` and :meth:`ot.gpu.to_np` that perform the conversion -explicitly. - -.. warning:: - - Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not - imported by default. If you want to - use it you have to specifically import it with :code:`import ot.gpu` . - - -Solving OT with Multiple backends ---------------------------------- +Solving OT with Multiple backends on CPU/GPU +-------------------------------------------- .. _backends_section: @@ -1002,7 +938,21 @@ the function will be the same type as the inputs and on the same device. When possible all computations are done on the same device and also when possible the output will be differentiable with respect to the input of the function. +GPU acceleration +^^^^^^^^^^^^^^^^ + +The backends provide automatic computations/compatibility on GPU for most of the +POT functions. +Note that all solvers relying on the exact OT solver en C++ will need to solve the +problem on CPU which can incur some memory copy overhead and be far from optimal +when all other computations are done on GPU. They will still work on array on +GPU since the copy is done automatically. +Some of the functions that rely on the exact C++ solver are: + +- :any:`ot.emd`, :any:`ot.emd2` +- :any:`ot.gromov_wasserstein`, :any:`ot.gromov_wasserstein2` +- :any:`ot.optim.cg` List of compatible Backends ^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -1010,18 +960,21 @@ List of compatible Backends - `Numpy `_ (all functions and solvers) - `Pytorch `_ (all outputs differentiable w.r.t. inputs) - `Jax `_ (Some functions are differentiable some require a wrapper) +- `Tensorflow `_ (all outputs differentiable w.r.t. inputs) +- `Cupy `_ (no differentiation, GPU only) + -List of compatible functions +List of compatible modules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This list will get longer for new releases and will hopefully disappear when POT become fully implemented with the backend. -- :any:`ot.emd` -- :any:`ot.emd2` -- :any:`ot.sinkhorn` -- :any:`ot.sinkhorn2` -- :any:`ot.dist` +- :any:`ot.bregman` +- :any:`ot.gromov` (some functions use CPU only solvers with copy overhead) +- :any:`ot.optim` (some functions use CPU only solvers with copy overhead) +- :any:`ot.sliced` +- :any:`ot.utils` (partial) FAQ diff --git a/ot/__init__.py b/ot/__init__.py index f55819d..1ea7403 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -50,7 +50,7 @@ from .gromov import (gromov_wasserstein, gromov_wasserstein2, # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.1.0" +__version__ = "0.8.2dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', -- cgit v1.2.3 From 5861209f27fe8e022eca2ed2c8d0bb1da4a1146b Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 17 Jan 2022 13:45:58 +0100 Subject: [MRG] Default pygment color for doc (#335) * back to default pygment * add images back * move static images and make it work --- docs/source/_static/images/bak.png | Bin 0 -> 304669 bytes docs/source/_static/images/sinkhorn.png | Bin 0 -> 37204 bytes docs/source/conf.py | 2 +- examples/plot_Intro_OT.py | 4 ++-- 4 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 docs/source/_static/images/bak.png create mode 100644 docs/source/_static/images/sinkhorn.png diff --git a/docs/source/_static/images/bak.png b/docs/source/_static/images/bak.png new file mode 100644 index 0000000..25e7e8e Binary files /dev/null and b/docs/source/_static/images/bak.png differ diff --git a/docs/source/_static/images/sinkhorn.png b/docs/source/_static/images/sinkhorn.png new file mode 100644 index 0000000..e003e13 Binary files /dev/null and b/docs/source/_static/images/sinkhorn.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 163851f..d1b8426 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -141,7 +141,7 @@ exclude_patterns = [] #show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = 'default' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py index f282950..219aa51 100644 --- a/examples/plot_Intro_OT.py +++ b/examples/plot_Intro_OT.py @@ -58,7 +58,7 @@ help(ot.dist) # number of Bakeries to Cafés in a City (in this case Manhattan). We did a # quick google map search in Manhattan for bakeries and Cafés: # -# .. image:: images/bak.png +# .. image:: ../_static/images/bak.png # :align: center # :alt: bakery-cafe-manhattan # :width: 600px @@ -233,7 +233,7 @@ print('Wasserstein loss (EMD) = {0:.2f}'.format(W)) # The Sinkhorn algorithm is very simple to code. You can implement it directly # using the following pseudo-code # -# .. image:: images/sinkhorn.png +# .. image:: ../_static/images/sinkhorn.png # :align: center # :alt: Sinkhorn algorithm # :width: 440px -- cgit v1.2.3 From 263c5842664c1dff4f8e58111d6bddb33927539e Mon Sep 17 00:00:00 2001 From: Bastian Rieck Date: Wed, 19 Jan 2022 08:39:04 +0100 Subject: [MRG] Fix instantiation of `ValFunction` (which raises a warning with PyTorch) (#338) * Not instantiating `ValFunction` `ValFunction` should not be instantiated since `autograd` functions are supposed to only ever use static methods. This solves a warning message raised by PyTorch. * Updated release information * Fixed PR number --- RELEASES.md | 4 ++++ ot/backend.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 9b92d97..c6ab9c3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,10 @@ - Better list of related examples in quick start guide with `minigallery` (PR #334) +#### Closed issues + +- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338) + ## 0.8.1.0 *December 2021* diff --git a/ot/backend.py b/ot/backend.py index 58b652b..6e0bc3d 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1397,7 +1397,7 @@ class TorchBackend(Backend): def set_gradients(self, val, inputs, grads): - Func = self.ValFunction() + Func = self.ValFunction res = Func.apply(val, grads, *inputs) -- cgit v1.2.3 From d7c709e2bae3bafec9efad87e758919c8db61933 Mon Sep 17 00:00:00 2001 From: Jakub Zadrożny Date: Fri, 21 Jan 2022 08:50:19 +0100 Subject: [MRG] Implement Sinkhorn in log-domain for WDA (#336) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [MRG] Implement Sinkhorn in log-domain for WDA * for small values of the regularization parameter (reg) the current implementation runs into numerical issues (nans and infs) * this can be resolved by using log-domain implementation of the sinkhorn algorithm * Add feature to RELEASES and contributor name * Add 'sinkhorn_method' parameter to WDA * use the standard Sinkhorn solver by default (faster) * use log-domain Sinkhorn if asked by the user Co-authored-by: Jakub Zadrożny Co-authored-by: Rémi Flamary --- RELEASES.md | 2 ++ ot/dr.py | 44 ++++++++++++++++++++++++++++++++++++++++++-- test/test_dr.py | 22 ++++++++++++++++++++++ 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index c6ab9c3..a5fcbe1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,8 @@ #### New features - Better list of related examples in quick start guide with `minigallery` (PR #334) +- Add optional log-domain Sinkhorn implementation in WDA to support smaller values + of the regularization parameter (PR #336) #### Closed issues diff --git a/ot/dr.py b/ot/dr.py index 1671ca0..0955c55 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -11,6 +11,7 @@ Dimension reduction with OT # Author: Remi Flamary # Minhui Huang +# Jakub Zadrozny # # License: MIT License @@ -43,6 +44,28 @@ def sinkhorn(w1, w2, M, reg, k): return G +def logsumexp(M, axis): + r"""Log-sum-exp reduction compatible with autograd (no numpy implementation) + """ + amax = np.amax(M, axis=axis, keepdims=True) + return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis) + + +def sinkhorn_log(w1, w2, M, reg, k): + r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd) + """ + Mr = -M / reg + ui = np.zeros((M.shape[0],)) + vi = np.zeros((M.shape[1],)) + log_w1 = np.log(w1) + log_w2 = np.log(w2) + for i in range(k): + vi = log_w2 - logsumexp(Mr + ui[:, None], 0) + ui = log_w1 - logsumexp(Mr + vi[None, :], 1) + G = np.exp(ui[:, None] + Mr + vi[None, :]) + return G + + def split_classes(X, y): r"""split samples in :math:`\mathbf{X}` by classes in :math:`\mathbf{y}` """ @@ -110,7 +133,7 @@ def fda(X, y, p=2, reg=1e-16): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): +def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False): r""" Wasserstein Discriminant Analysis :ref:`[11] ` @@ -126,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no - :math:`W` is entropic regularized Wasserstein distances - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sparse cost matrices, you should use the + :py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical + errors, but can be slow in practice. + Parameters ---------- X : ndarray, shape (n, d) @@ -139,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no solver : None | str, optional None for steepest descent or 'TrustRegions' for trust regions algorithm else should be a pymanopt.solvers + sinkhorn_method : str + method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log' P0 : ndarray, shape (d, p) Initial starting point for projection. normalize : bool, optional @@ -161,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. """ # noqa + if sinkhorn_method.lower() == 'sinkhorn': + sinkhorn_solver = sinkhorn + elif sinkhorn_method.lower() == 'sinkhorn_log': + sinkhorn_solver = sinkhorn_log + else: + raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method) + mx = np.mean(X) X -= mx.reshape((1, -1)) @@ -193,7 +233,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k) + G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k) if j == 0: loss_w += np.sum(G * M) else: diff --git a/test/test_dr.py b/test/test_dr.py index 741f2ad..6d7fc9a 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -60,6 +60,28 @@ def test_wda(): np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_wda_low_reg(): + + n_samples = 100 # nb samples in source and target datasets + np.random.seed(0) + + # generate gaussian dataset + xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples) + + n_features_noise = 8 + + xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise))) + + p = 2 + + Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log') + + projwda(xs) + + np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) + + @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_wda_normalized(): -- cgit v1.2.3 From 71a57c68ea9eb2bc948c4dd1cce9928f34bf20e8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 28 Jan 2022 17:40:16 +0100 Subject: [MRG] Backend implementation of the free support barycenter (#340) * backend version barycenter * new tests * cleanup release file and doc * f*ing pep8 * remove unused variable --- RELEASES.md | 5 ++++- ot/lp/__init__.py | 28 +++++++++++++++------------- test/test_ot.py | 17 +++++++++++++++++ test/test_utils.py | 2 +- 4 files changed, 37 insertions(+), 15 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index a5fcbe1..94c853b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,6 @@ # Releases + ## 0.8.2dev Development #### New features @@ -7,10 +8,12 @@ - Better list of related examples in quick start guide with `minigallery` (PR #334) - Add optional log-domain Sinkhorn implementation in WDA to support smaller values of the regularization parameter (PR #336) +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) #### Closed issues -- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338) +- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR + #338) ## 0.8.1.0 *December 2021* diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5da897d..2ff7c1f 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -535,18 +535,18 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Parameters ---------- - measures_locations : list of N (k_i,d) numpy.ndarray + measures_locations : list of N (k_i,d) array-like The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) numpy.ndarray + measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) np.ndarray + X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter - b : (k,) np.ndarray + b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) np.ndarray + weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional @@ -564,7 +564,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None Returns ------- - X : (k,d) np.ndarray + X : (k,d) array-like Support locations (on k atoms) of the barycenter @@ -577,15 +577,17 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None """ + nx = get_backend(*measures_locations,*measures_weights,X_init) + iter_count = 0 N = len(measures_locations) k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = np.ones((k,)) / k + b = nx.ones((k,),type_as=X_init) / k if weights is None: - weights = np.ones((N,)) / N + weights = nx.ones((N,),type_as=X_init) / N X = X_init @@ -596,15 +598,15 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None while (displacement_square_norm > stopThr and iter_count < numItermax): - T_sum = np.zeros((k, d)) + T_sum = nx.zeros((k, d),type_as=X_init) + - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, - weights.tolist()): + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): M_i = dist(X, measure_locations_i) T_i = emd(b, measure_weights_i, M_i, numThreads=numThreads) - T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) + T_sum = T_sum + weight_i * 1. / b[:,None] * nx.dot(T_i, measure_locations_i) - displacement_square_norm = np.sum(np.square(T_sum - X)) + displacement_square_norm = nx.sum((T_sum - X)**2) if log: displacement_square_norms.append(displacement_square_norm) diff --git a/test/test_ot.py b/test/test_ot.py index 53edf4f..e8e2d97 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -302,6 +302,23 @@ def test_free_support_barycenter(): np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) +def test_free_support_barycenter_backends(nx): + + measures_locations = [np.array([-1.]).reshape((1, 1)), np.array([1.]).reshape((1, 1))] + measures_weights = [np.array([1.]), np.array([1.])] + X_init = np.array([-12.]).reshape((1, 1)) + + X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) + + measures_locations2 = [nx.from_numpy(x) for x in measures_locations] + measures_weights2 = [nx.from_numpy(x) for x in measures_weights] + X_init2 = nx.from_numpy(X_init) + + X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2) + + np.testing.assert_allclose(X, nx.to_numpy(X2)) + + @pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available") def test_lp_barycenter_cvxopt(): a1 = np.array([1.0, 0, 0])[:, None] diff --git a/test/test_utils.py b/test/test_utils.py index 6b476b2..8b23c22 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -122,7 +122,7 @@ def test_dist(): 'braycurtis', 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', - 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule' + 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule' ] # those that support weights metrics = ['mahalanobis', 'seuclidean'] # do not support weights depending on scipy's version -- cgit v1.2.3 From a5e0f0d40d5046a6639924347ef97e2ac80ad0c9 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 2 Feb 2022 11:53:12 +0100 Subject: [MRG] Add weak OT solver (#341) * add info in release file * update tests * pep8 * add weak OT example * update plot in doc * correction ewample with empirical sinkhorn * better thumbnail * comment from review * update documenation --- README.md | 3 + RELEASES.md | 8 ++- docs/source/all.rst | 1 + examples/others/plot_WeakOT_VS_OT.py | 98 +++++++++++++++++++++++++++ examples/plot_OT_2D_samples.py | 5 +- ot/__init__.py | 5 +- ot/gromov.py | 16 +++++ ot/lp/__init__.py | 9 ++- ot/lp/cvx.py | 1 - ot/utils.py | 12 +++- ot/weak.py | 124 +++++++++++++++++++++++++++++++++++ test/test_bregman.py | 13 ++-- test/test_ot.py | 2 +- test/test_utils.py | 18 ++++- test/test_weak.py | 54 +++++++++++++++ 15 files changed, 343 insertions(+), 26 deletions(-) create mode 100644 examples/others/plot_WeakOT_VS_OT.py create mode 100644 ot/weak.py create mode 100644 test/test_weak.py diff --git a/README.md b/README.md index 17fbe81..a7627df 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples): * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. +* Weak OT solver between empirical distributions [39] * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] @@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. + +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 94c853b..4d05582 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,10 +5,12 @@ #### New features -- Better list of related examples in quick start guide with `minigallery` (PR #334) +- Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values - of the regularization parameter (PR #336) -- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) + of the regularization parameter (PR #336). +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340). +- Add weak OT solver + example (PR #341). + #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 7f85a91..76d2ff5 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -28,6 +28,7 @@ API and modules unbalanced partial sliced + weak .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/others/plot_WeakOT_VS_OT.py b/examples/others/plot_WeakOT_VS_OT.py new file mode 100644 index 0000000..a29c875 --- /dev/null +++ b/examples/others/plot_WeakOT_VS_OT.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Weak Optimal Transport VS exact Optimal Transport +==================================================== + +Illustration of 2D optimal transport between distributions that are weighted +sum of diracs. The OT matrix is plotted with the samples. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +############################################################################## +# Generate data an plot it +# ------------------------ + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +pl.figure(2) +pl.imshow(M, interpolation='nearest') +pl.title('Cost matrix M') + + +############################################################################## +# Compute Weak OT and exact OT solutions +# -------------------------------------- + +#%% EMD + +G0 = ot.emd(a, b, M) + +#%% Weak OT + +Gweak = ot.weak_optimal_transport(xs, xt, a, b) + + +############################################################################## +# Plot weak OT and exact OT solutions +# -------------------------------------- + +pl.figure(3, (8, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(G0, interpolation='nearest') +pl.title('OT matrix') + +pl.subplot(1, 2, 2) +pl.imshow(Gweak, interpolation='nearest') +pl.title('Weak OT matrix') + +pl.figure(4, (8, 5)) + +pl.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('OT matrix with samples') + +pl.subplot(1, 2, 2) +ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1]) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Weak OT matrix with samples') diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index af1bc12..c3a7cd8 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -42,7 +42,6 @@ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples # loss matrix M = ot.dist(xs, xt) -M /= M.max() ############################################################################## # Plot data @@ -87,7 +86,7 @@ pl.title('OT matrix with samples') #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Gs = ot.sinkhorn(a, b, M, lambd) @@ -112,7 +111,7 @@ pl.show() #%% sinkhorn # reg term -lambd = 1e-3 +lambd = 1e-1 Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd) diff --git a/ot/__init__.py b/ot/__init__.py index 1ea7403..7253318 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -36,6 +36,7 @@ from . import unbalanced from . import partial from . import backend from . import regpath +from . import weak # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -46,7 +47,7 @@ from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) - +from .weak import weak_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq @@ -59,5 +60,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', - 'max_sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/gromov.py b/ot/gromov.py index 6544260..b7e7949 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F - :math:`\mathbf{q}`: distribution in the target space - `L`: loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Parameters ---------- C1 : array-like, shape (ns, ns) @@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` Parameters @@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + Note that when using backends, this loss function is differentiable wrt the marices and weights for quadratic loss using the gradients from [38]_. diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2ff7c1f..d9b6fa9 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -26,6 +26,8 @@ from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend + + __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] `. @@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. Uses the algorithm proposed in :ref:`[1] `. @@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X + diff --git a/ot/lp/cvx.py b/ot/lp/cvx.py index 869d450..fbf3c0e 100644 --- a/ot/lp/cvx.py +++ b/ot/lp/cvx.py @@ -11,7 +11,6 @@ import numpy as np import scipy as sp import scipy.sparse as sps - try: import cvxopt from cvxopt import solvers, matrix, spmatrix diff --git a/ot/utils.py b/ot/utils.py index e6c93c8..725ca00 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -116,7 +116,7 @@ def proj_simplex(v, z=1): return w -def unif(n): +def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). @@ -124,13 +124,19 @@ def unif(n): ---------- n : int number of bins in the histogram + type_as : array_like + array of the same type of the expected output (numpy/pytorch/jax) Returns ------- - h : np.array (`n`,) + h : array_like (`n`,) histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ - return np.ones((n,)) / n + if type_as is None: + return np.ones((n,)) / n + else: + nx = get_backend(type_as) + return nx.ones((n,)) / n def clean_zeros(a, b, M): diff --git a/ot/weak.py b/ot/weak.py new file mode 100644 index 0000000..f7d5b23 --- /dev/null +++ b/ot/weak.py @@ -0,0 +1,124 @@ +""" +Weak optimal ransport solvers +""" + +# Author: Remi Flamary +# +# License: MIT License + +from .backend import get_backend +from .optim import cg +import numpy as np + +__all__ = ['weak_optimal_transport'] + + +def weak_optimal_transport(Xa, Xb, a=None, b=None, verbose=False, log=False, G0=None, **kwargs): + r"""Solves the weak optimal transport problem between two empirical distributions + + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \|X_a-diag(1/a)\gammaX_b\|_F^2 + + s.t. \ \gamma \mathbf{1} = \mathbf{a} + + \gamma^T \mathbf{1} = \mathbf{b} + + \gamma \geq 0 + + where : + + - :math:`X_a` :math:`X_b` are the sample matrices. + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights + + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed + in :ref:`[39] `. + + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list)) + numItermax : int, optional + Max number of iterations + numItermaxEmd : int, optional + Max number of iterations for emd + stopThr : float, optional + Stop threshold on the relative variation (>0) + stopThr2 : float, optional + Stop threshold on the absolute variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status + + + .. _references-weak: + References + ---------- + .. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + Kantorovich duality for general transport costs and applications. + Journal of Functional Analysis, 273(11), 3327-3405. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT + ot.optim.cg : General regularized OT + """ + + nx = get_backend(Xa, Xb) + + Xa2 = nx.to_numpy(Xa) + Xb2 = nx.to_numpy(Xb) + + if a is None: + a2 = np.ones((Xa.shape[0])) / Xa.shape[0] + else: + a2 = nx.to_numpy(a) + if b is None: + b2 = np.ones((Xb.shape[0])) / Xb.shape[0] + else: + b2 = nx.to_numpy(b) + + # init uniform + if G0 is None: + T0 = a2[:, None] * b2[None, :] + else: + T0 = nx.to_numpy(G0) + + # weak OT loss + def f(T): + return np.dot(a2, np.sum((Xa2 - np.dot(T, Xb2) / a2[:, None])**2, 1)) + + # weak OT gradient + def df(T): + return -2 * np.dot(Xa2 - np.dot(T, Xb2) / a2[:, None], Xb2.T) + + # solve with conditional gradient and return solution + if log: + res, log = cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs) + log['u'] = nx.from_numpy(log['u'], type_as=Xa) + log['v'] = nx.from_numpy(log['v'], type_as=Xb) + return nx.from_numpy(res, type_as=Xa), log + else: + return nx.from_numpy(cg(a2, b2, 0, 1, f, df, T0, log=log, verbose=verbose, **kwargs), type_as=Xa) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6e90aa4..1419f9b 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -60,7 +60,7 @@ def test_convergence_warning(method): ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) -def test_not_impemented_method(): +def test_not_implemented_method(): # test sinkhorn w = 10 n = w ** 2 @@ -635,7 +635,7 @@ def test_wasserstein_bary_2d(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -667,7 +667,7 @@ def test_wasserstein_bary_2d_debiased(nx, method): with pytest.raises(NotImplementedError): ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) else: - bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass_np, log_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method, verbose=True, log=True) bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) @@ -940,14 +940,11 @@ def test_screenkhorn(nx): bb = nx.from_numpy(b) M_nx = nx.from_numpy(M, type_as=ab) - # np sinkhorn - G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-1, uniform=True, verbose=True)) # check marginals - np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) np.testing.assert_allclose(G_sink.sum(1), G_screen.sum(1), atol=1e-02) diff --git a/test/test_ot.py b/test/test_ot.py index e8e2d97..3e2d845 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -232,7 +232,7 @@ def test_emd2_multi(): # Gaussian distributions a = gauss(n, m=20, s=5) # m= mean, s= std - ls = np.arange(20, 500, 20) + ls = np.arange(20, 500, 100) nb = len(ls) b = np.zeros((n, nb)) for i in range(nb): diff --git a/test/test_utils.py b/test/test_utils.py index 8b23c22..5ad167b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -62,12 +62,12 @@ def test_tic_toc(): import time ot.tic() - time.sleep(0.5) + time.sleep(0.1) t = ot.toc() t2 = ot.toq() # test timing - np.testing.assert_allclose(0.5, t, rtol=1e-1, atol=1e-1) + np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1) # test toc vs toq np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1) @@ -94,10 +94,22 @@ def test_unif(): np.testing.assert_allclose(1, np.sum(u)) -def test_dist(): +def test_unif_backend(nx): n = 100 + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + u = ot.unif(n, type_as=tp) + + np.testing.assert_allclose(1, np.sum(nx.to_numpy(u)), atol=1e-6) + + +def test_dist(): + + n = 10 + rng = np.random.RandomState(0) x = rng.randn(n, 2) diff --git a/test/test_weak.py b/test/test_weak.py new file mode 100644 index 0000000..c4c3278 --- /dev/null +++ b/test/test_weak.py @@ -0,0 +1,54 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import numpy as np + + +def test_weak_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G, log = ot.weak_optimal_transport(xs, xt, u, u, log=True) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + # chaeck that identity is recovered + G = ot.weak_optimal_transport(xs, xs, G0=np.eye(n) / n) + + # check G is identity + np.testing.assert_allclose(G, np.eye(n) / n) + + # check constraints + np.testing.assert_allclose(u, G.sum(1)) + np.testing.assert_allclose(u, G.sum(0)) + + +def test_weak_ot_bakends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + G = ot.weak_optimal_transport(xs, xt, u, u) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) + + np.testing.assert_allclose(nx.to_numpy(G2), G) -- cgit v1.2.3 From 50c0f17d00e3492c4d56a356af30cf00d6d07913 Mon Sep 17 00:00:00 2001 From: Cédric Vincent-Cuaz Date: Fri, 11 Feb 2022 10:53:38 +0100 Subject: [MRG] GW dictionary learning (#319) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add fgw dictionary learning feature * add fgw dictionary learning feature * plot gromov wasserstein dictionary learning * Update __init__.py * fix pep8 errors exact E501 line too long * fix last pep8 issues * add unitary tests for (F)GW dictionary learning without using autodifferentiable functions * correct tests for (F)GW dictionary learning without using autodiff * correct tests for (F)GW dictionary learning without using autodiff * fix docs and notations * answer to review: improve tests, docs, examples + make node weights optional * fix pep8 and examples * improve docs + tests + thumbnail * make example faster * improve ex * update README.md * make GDL tests faster Co-authored-by: Rémi Flamary --- README.md | 2 + RELEASES.md | 2 +- .../plot_gromov_wasserstein_dictionary_learning.py | 357 +++++++ ot/__init__.py | 4 - ot/gromov.py | 1074 +++++++++++++++++++- test/test_gromov.py | 554 +++++++++- 6 files changed, 1954 insertions(+), 39 deletions(-) create mode 100755 examples/gromov/plot_gromov_wasserstein_dictionary_learning.py diff --git a/README.md b/README.md index a7627df..c6bfd9c 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. POT provides the following Machine Learning related solvers: @@ -198,6 +199,7 @@ The contributors to this library are * [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) * [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): diff --git a/RELEASES.md b/RELEASES.md index 4d05582..925920a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -10,7 +10,7 @@ of the regularization parameter (PR #336). - Backend implementation for `ot.lp.free_support_barycenter` (PR #340). - Add weak OT solver + example (PR #341). - +- Add (F)GW linear dictionary learning solvers + example (PR #319) #### Closed issues diff --git a/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py new file mode 100755 index 0000000..1fdc3b9 --- /dev/null +++ b/examples/gromov/plot_gromov_wasserstein_dictionary_learning.py @@ -0,0 +1,357 @@ +# -*- coding: utf-8 -*- + +r""" +================================= +(Fused) Gromov-Wasserstein Linear Dictionary Learning +================================= + +In this exemple, we illustrate how to learn a Gromov-Wasserstein dictionary on +a dataset of structured data such as graphs, denoted +:math:`\{ \mathbf{C_s} \}_{s \in [S]}` where every nodes have uniform weights. +Given a dictionary :math:`\mathbf{C_{dict}}` composed of D structures of a fixed +size nt, each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` +is modeled as a convex combination :math:`\mathbf{w_s} \in \Sigma_D` of these +dictionary atoms as :math:`\sum_d w_{s,d} \mathbf{C_{dict}[d]}`. + + +First, we consider a dataset composed of graphs generated by Stochastic Block models +with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters +varying in :math:`\{ 1, 2, 3\}`. We learn a dictionary of 3 atoms, by minimizing +the Gromov-Wasserstein distance from all samples to its model in the dictionary +with respect to the dictionary atoms. + +Second, we illustrate the extension of this dictionary learning framework to +structured data endowed with node features by using the Fused Gromov-Wasserstein +distance. Starting from the aforementioned dataset of unattributed graphs, we +add discrete labels uniformly depending on the number of clusters. Then we learn +and visualize attributed graph atoms where each sample is modeled as a joint convex +combination between atom structures and features. + + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +from sklearn.manifold import MDS +from ot.gromov import gromov_wasserstein_linear_unmixing, gromov_wasserstein_dictionary_learning, fused_gromov_wasserstein_linear_unmixing, fused_gromov_wasserstein_dictionary_learning +import ot +import networkx +from networkx.generators.community import stochastic_block_model as sbm +# %% +# ============================================================================= +# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. +# ============================================================================= + +np.random.seed(42) + +N = 60 # number of graphs in the dataset +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability. +clusters = [1, 2, 3] +Nc = N // len(clusters) # number of graphs by cluster +nlabels = len(clusters) +dataset = [] +labels = [] + +p_inter = 0.1 +p_intra = 0.9 +for n_cluster in clusters: + for i in range(Nc): + n_nodes = int(np.random.uniform(low=30, high=50)) + + if n_cluster > 1: + P = p_inter * np.ones((n_cluster, n_cluster)) + np.fill_diagonal(P, p_intra) + else: + P = p_intra * np.eye(1) + sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32) + G = sbm(sizes, P, seed=i, directed=False) + C = networkx.to_numpy_array(G) + dataset.append(C) + labels.append(n_cluster) + + +# Visualize samples + +def plot_graph(x, C, binary=True, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if binary: + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + else: # connection intensity proportional to C[i,j] + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color='C0', s=50.) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Estimate the gromov-wasserstein dictionary from the dataset +# ============================================================================= + + +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] + +D = 3 # 3 atoms in the dictionary +nt = 6 # of 6 nodes each + +q = ot.unif(nt) +reg = 0. # regularization coefficient to promote sparsity of unmixings {w_s} + +Cdict_GW, log = gromov_wasserstein_dictionary_learning( + Cs=dataset, D=D, nt=nt, ps=ps, q=q, epochs=10, batch_size=16, + learning_rate=0.1, reg=reg, projection='nonnegative_symmetric', + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution over epochs +pl.figure(2, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + + +# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white) + +pl.figure(3, (12, 8)) +pl.clf() +for idx_atom, atom in enumerate(Cdict_GW): + scaled_atom = (atom - atom.min()) / (atom.max() - atom.min()) + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, atom / atom.max(), binary=False, color='C0', s=100.) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() +#%% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for C in dataset: + p = ot.unif(C.shape[0]) + unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing( + C, Cdict_GW, p=p, q=q, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), + max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + + +# Compute the 2D representation of the unmixing living in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(4, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Endow the dataset with node features +# ============================================================================= + +# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters +# 1 cluster --> 0 as nodes feature +# 2 clusters --> 1 as nodes feature +# 3 clusters --> 2 as nodes feature +# features are one-hot encoded following these assignments +dataset_features = [] +for i in range(len(dataset)): + n = dataset[i].shape[0] + F = np.zeros((n, 3)) + if i < Nc: # graph with 1 cluster + F[:, 0] = 1. + elif i < 2 * Nc: # graph with 2 clusters + F[:, 1] = 1. + else: # graph with 3 clusters + F[:, 2] = 1. + dataset_features.append(F) + +pl.figure(5, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + F = dataset_features[(c - 1) * Nc] + colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color=colors, s=50) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() +# %% +# ============================================================================= +# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs +# ============================================================================= +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] +D = 3 # 6 atoms instead of 3 +nt = 6 +q = ot.unif(nt) +reg = 0.001 +alpha = 0.5 # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein + + +Cdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning( + Cs=dataset, Ys=dataset_features, D=D, nt=nt, ps=ps, q=q, alpha=alpha, + epochs=10, batch_size=16, learning_rate_C=0.1, learning_rate_Y=0.1, reg=reg, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=30, max_iter_inner=300, + projection='nonnegative_symmetric', use_log=True, use_adam_optimizer=True, verbose=True +) +# visualize loss evolution +pl.figure(6, (4, 3)) +pl.clf() +pl.title('loss evolution by epoch', fontsize=14) +pl.plot(log['loss_epochs']) +pl.xlabel('epochs', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the estimated dictionary atoms +# ============================================================================= + +pl.figure(7, (12, 8)) +pl.clf() +max_features = Ydict_FGW.max() +min_features = Ydict_FGW.min() + +for idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)): + scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min()) + #scaled_F = 2 * (Fatom - min_features) / (max_features - min_features) + colors = ['C%s' % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])] + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - scaled_atom) + pl.subplot(2, D, idx_atom + 1) + pl.title('(attributed graph) atom ' + str(idx_atom + 1), fontsize=14) + plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100) + pl.axis("off") + pl.subplot(2, D, D + idx_atom + 1) + pl.title('(matrix) atom %s \n' % (idx_atom + 1), fontsize=14) + pl.imshow(scaled_atom, interpolation='nearest') + pl.colorbar() + pl.axis("off") +pl.tight_layout() +pl.show() + +# %% +# ============================================================================= +# Visualization of the embedding space +# ============================================================================= + +unmixings = [] +reconstruction_errors = [] +for i in range(len(dataset)): + C = dataset[i] + Y = dataset_features[i] + p = ot.unif(C.shape[0]) + unmixing, Cembedded, Yembedded, OT, reconstruction_error = fused_gromov_wasserstein_linear_unmixing( + C, Y, Cdict_FGW, Ydict_FGW, p=p, q=q, alpha=alpha, + reg=reg, tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=30, max_iter_inner=300 + ) + unmixings.append(unmixing) + reconstruction_errors.append(reconstruction_error) +unmixings = np.array(unmixings) +print('cumulated reconstruction error:', np.array(reconstruction_errors).sum()) + +# Visualize unmixings in the 2-simplex of probability +unmixings2D = np.zeros(shape=(N, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(8, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=40., label='%s clusters' % (cluster + 1)) + +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=80., label='atoms') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 7253318..bda7a35 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -1,5 +1,4 @@ """ - .. warning:: The list of automatically imported sub-modules is as follows: :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` @@ -7,13 +6,10 @@ :py:mod:`ot.gromov`, :py:mod:`ot.smooth` :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` , :py:mod:`ot.unbalanced`. - The following sub-modules are not imported due to additional dependencies: - - :any:`ot.dr` : depends on :code:`pymanopt` and :code:`autograd`. - :any:`ot.gpu` : depends on :code:`cupy` and a CUDA GPU. - :any:`ot.plot` : depends on :code:`matplotlib` - """ # Author: Remi Flamary diff --git a/ot/gromov.py b/ot/gromov.py index b7e7949..f5a1f91 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -7,6 +7,7 @@ Gromov-Wasserstein and Fused-Gromov-Wasserstein solvers # Nicolas Courty # Rémi Flamary # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -17,7 +18,7 @@ from .bregman import sinkhorn from .utils import dist, UndefinedParameter, list_to_array from .optim import cg from .lp import emd_1d, emd -from .utils import check_random_state +from .utils import check_random_state, unif from .backend import get_backend @@ -320,7 +321,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -365,6 +366,9 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -389,18 +393,26 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F """ p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - nx = get_backend(p0, q0, C10, C20) - + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) - constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) - G0 = p[:, None] * q[None, :] + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) def f(G): return gwloss(constC, hC1, hC2, G) @@ -418,7 +430,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, G0=None, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -467,6 +479,9 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. Returns ------- @@ -491,9 +506,12 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= """ p, q = list_to_array(p, q) - p0, q0, C10, C20 = p, q, C1, C2 - nx = get_backend(p0, q0, C10, C20) + if G0 is None: + nx = get_backend(p0, q0, C10, C20) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) @@ -502,7 +520,13 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) def f(G): return gwloss(constC, hC1, hC2, G) @@ -533,7 +557,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= return gw -def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): +def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): r""" Computes the FGW transport between two graphs (see :ref:`[24] `) @@ -578,6 +602,9 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. log : bool, optional record log if True **kwargs : dict @@ -600,20 +627,28 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, (ICML). 2019. """ p, q = list_to_array(p, q) - p0, q0, C10, C20, M0 = p, q, C1, C2, M - nx = get_backend(p0, q0, C10, C20, M0) + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) C1 = nx.to_numpy(C10) C2 = nx.to_numpy(C20) M = nx.to_numpy(M0) + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] - def f(G): return gwloss(constC, hC1, hC2, G) @@ -622,19 +657,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) - log['fgw_dist'] = fgw_dist log['u'] = nx.from_numpy(log['u'], type_as=C10) log['v'] = nx.from_numpy(log['v'], type_as=C10) return nx.from_numpy(res, type_as=C10), log - else: return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) -def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): +def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, G0=None, log=False, **kwargs): r""" Computes the FGW distance between two graphs see (see :ref:`[24] `) @@ -683,6 +715,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 armijo : bool, optional If True the step of the line-search is found via an armijo research. Else closed form is used. If there are convergence issues use False. + G0: array-like, shape (ns,nt), optional + If None the initial transport plan of the solver is pq^T. + Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. log : bool, optional Record log if True. **kwargs : dict @@ -711,7 +746,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 p, q = list_to_array(p, q) p0, q0, C10, C20, M0 = p, q, C1, C2, M - nx = get_backend(p0, q0, C10, C20, M0) + if G0 is None: + nx = get_backend(p0, q0, C10, C20, M0) + else: + G0_ = G0 + nx = get_backend(p0, q0, C10, C20, M0, G0_) p = nx.to_numpy(p) q = nx.to_numpy(q) @@ -721,7 +760,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) - G0 = p[:, None] * q[None, :] + if G0 is None: + G0 = p[:, None] * q[None, :] + else: + G0 = nx.to_numpy(G0_) + # Check marginals of G0 + np.testing.assert_allclose(G0.sum(axis=1), p, atol=1e-08) + np.testing.assert_allclose(G0.sum(axis=0), q, atol=1e-08) def f(G): return gwloss(constC, hC1, hC2, G) @@ -1796,3 +1841,988 @@ def update_feature_matrix(lambdas, Ys, Ts, p): for s in range(len(Ts)) ]) return tmpsum + + +def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}}, \{\mathbf{w_s} \}_{s \leq S}} \sum_{s=1}^S GW_2(\mathbf{C_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) - reg\| \mathbf{w_s} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - reg is the regularization coefficient. + + The stochastic algorithm used for estimating the graph dictionary atoms as proposed in [38] + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns, ns). + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate: float, optional + Learning rate used for the stochastic gradient descent. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + projection: str , optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epochs. + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + # Handle backend of non-optional arguments + Cs0 = Cs + nx = get_backend(*Cs0) + Cs = [nx.to_numpy(C) for C in Cs0] + dataset_size = len(Cs) + # Handle backend of optional arguments + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0 + if use_adam_optimizer: + adam_moments = _initialize_adam_optimizer(Cdict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + Cdict_best_state = Cdict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + # batch sampling + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ts[batch_idx], current_loss = gromov_wasserstein_linear_unmixing( + Cs[C_idx], Cdict, reg=reg, p=ps[C_idx], q=q, tol_outer=tol_outer, tol_inner=tol_inner, + max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + grad_Cdict += unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Cdict *= 2 / batch_size + if use_adam_optimizer: + Cdict, adam_moments = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate, adam_moments) + else: + Cdict -= learning_rate * grad_Cdict + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + if verbose: + print('--- epoch =', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), log + + +def _initialize_adam_optimizer(variable): + + # Initialization for our numpy implementation of adam optimizer + atoms_adam_m = np.zeros_like(variable) # Initialize first moment tensor + atoms_adam_v = np.zeros_like(variable) # Initialize second moment tensor + atoms_adam_count = 1 + + return {'mean': atoms_adam_m, 'var': atoms_adam_v, 'count': atoms_adam_count} + + +def _adam_stochastic_updates(variable, grad, learning_rate, adam_moments, beta_1=0.9, beta_2=0.99, eps=1e-09): + + adam_moments['mean'] = beta_1 * adam_moments['mean'] + (1 - beta_1) * grad + adam_moments['var'] = beta_2 * adam_moments['var'] + (1 - beta_2) * (grad**2) + unbiased_m = adam_moments['mean'] / (1 - beta_1**adam_moments['count']) + unbiased_v = adam_moments['var'] / (1 - beta_2**adam_moments['count']) + variable -= learning_rate * unbiased_m / (np.sqrt(unbiased_v) + eps) + adam_moments['count'] += 1 + + return variable, adam_moments + + +def gromov_wasserstein_linear_unmixing(C, Cdict, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): + r""" + Returns the Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}`. + + .. math:: + \min_{ \mathbf{w}} GW_2(\mathbf{C}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of size nt. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 1. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + gromov-wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + T: array-like (ns, nt) + Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \mathbf{q})` + current_loss: float + reconstruction error + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + C0, Cdict0 = C, Cdict + nx = get_backend(C0, Cdict0) + C = nx.to_numpy(C0) + Cdict = nx.to_numpy(Cdict0) + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + + w = unif(D) # Initialize uniformly the unmixing w + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + + const_q = q[:, None] * q[None, :] + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + T, log = gromov_wasserstein(C1=C, C2=Cembedded, p=p, q=q, loss_fun='square_loss', G0=T, log=True, armijo=False, **kwargs) + current_loss = log['gw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, current_loss = _cg_gromov_wasserstein_unmixing( + C=C, Cdict=Cdict, Cembedded=Cembedded, w=w, const_q=const_q, T=T, + starting_loss=current_loss, reg=reg, tol=tol_inner, max_iter=max_iter_inner, **kwargs + ) + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_gromov_wasserstein_unmixing(C, Cdict, Cembedded, w, const_q, T, starting_loss, reg=0., tol=10**(-5), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the linear unmixing w minimizing the Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w[d]*\mathbf{C_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d*C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg* \| \mathbf{w} \|_2^2 + + + Such that: + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is the (ns,ns) pairwise similarity matrix. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrices of nt points. + - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights. + - :math:`\mathbf{w}` is the linear unmixing of :math:`(\mathbf{C}, \mathbf{p})` onto :math:`(\sum_d w_d \mathbf{Cdict[d]}, \mathbf{q})`. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the current state of :math:`\mathbf{w}`. + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38] + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w[d]*Cdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + w: array-like, shape (D,) + Linear unmixing of the input structure onto the dictionary + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + T: array-like, shape (ns,nt) + fixed transport plan between the input structure and its representation in the dictionary. + p : array-like, shape (ns,) + Distribution in the source space. + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of w. Default is 0. + + Returns + ------- + w: ndarray (D,) + optimal unmixing of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary span given OT starting from previously optimal unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + + while (convergence_criterion > tol) and (count < max_iter): + + previous_loss = current_loss + # 1) Compute gradient at current point w + grad_w = 2 * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + grad_w -= 2 * reg * w + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff = _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: # not that the loss can be negative if reg >0 + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: # handle numerical issues around 0 + convergence_criterion = abs(previous_loss - current_loss) / 10**(-15) + count += 1 + + return w, Cembedded, current_loss + + +def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, const_q, const_TCT, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that: + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed C. + Each matrix in the dictionary must have the same size (nt,nt). + Cembedded: array-like, shape (nt,nt) + Embedded structure :math:`(\sum_d w_dCdict[d],q)` of :math:`(\mathbf{C},\mathbf{p})` onto the dictionary. Used to avoid redundant computations. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + reg : float, optional. + Coefficient of the negative quadratic regularization used to promote sparsity of :math:`\mathbf{w}`. + """ + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + a = trace_diffx - trace_diffw + b = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + + if a > 0: + gamma = min(1, max(0, - b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff + + +def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1., + Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False, + tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs): + r""" + Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s` + + .. math:: + \min_{\mathbf{C_{dict}},\mathbf{Y_{dict}}, \{\mathbf{w_s}\}_{s}} \sum_{s=1}^S FGW_{2,\alpha}(\mathbf{C_s}, \mathbf{Y_s}, \sum_{d=1}^D w_{s,d}\mathbf{C_{dict}[d]},\sum_{d=1}^D w_{s,d}\mathbf{Y_{dict}[d]}, \mathbf{p_s}, \mathbf{q}) \\ - reg\| \mathbf{w_s} \|_2^2 + + + Such that :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\forall s \leq S, \mathbf{C_s}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\forall s \leq S, \mathbf{Y_s}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\forall s \leq S, \mathbf{p_s}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + + The stochastic algorithm used for estimating the attributed graph dictionary atoms as proposed in [38] + + Parameters + ---------- + Cs : list of S symmetric array-like, shape (ns, ns) + List of Metric/Graph cost matrices of variable size (ns,ns). + Ys : list of S array-like, shape (ns, d) + List of feature matrix of variable size (ns,d) with d fixed. + D: int + Number of dictionary atoms to learn + nt: int + Number of samples within each dictionary atoms + alpha : float + Trade-off parameter of Fused Gromov-Wasserstein + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + ps : list of S array-like, shape (ns,), optional + Distribution in each source space C of Cs. Default is None and corresponds to uniform distibutions. + q : array-like, shape (nt,), optional + Distribution in the embedding space whose structure will be learned. Default is None and corresponds to uniform distributions. + epochs: int, optional + Number of epochs used to learn the dictionary. Default is 32. + batch_size: int, optional + Batch size for each stochastic gradient update of the dictionary. Set to the dataset size if the provided batch_size is higher than the dataset size. Default is 32. + learning_rate_C: float, optional + Learning rate used for the stochastic gradient descent on Cdict. Default is 1. + learning_rate_Y: float, optional + Learning rate used for the stochastic gradient descent on Ydict. Default is 1. + Cdict_init: list of D array-like with shape (nt, nt), optional + Used to initialize the dictionary structures Cdict. + If set to None (Default), the dictionary will be initialized randomly. + Else Cdict must have shape (D, nt, nt) i.e match provided shape features. + Ydict_init: list of D array-like with shape (nt, d), optional + Used to initialize the dictionary features Ydict. + If set to None, the dictionary features will be initialized randomly. + Else Ydict must have shape (D, nt, d) where d is the features dimension of inputs Ys and also match provided shape features. + projection: str, optional + If 'nonnegative' and/or 'symmetric' is in projection, the corresponding projection will be performed at each stochastic update of the dictionary + Else the set of atoms is :math:`R^{nt * nt}`. Default is 'nonnegative_symmetric' + log: bool, optional + If set to True, losses evolution by batches and epochs are tracked. Default is False. + use_adam_optimizer: bool, optional + If set to True, adam optimizer with default settings is used as adaptative learning rate strategy. + Else perform SGD with fixed learning rate. Default is True. + tol_outer : float, optional + Solver precision for the BCD algorithm, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport, measured by absolute relative error on consecutive losses. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + verbose : bool, optional + Print the reconstruction loss every epoch. Default is False. + + Returns + ------- + + Cdict_best_state : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + Ydict_best_state : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary. + The dictionary leading to the best loss over an epoch is saved and returned. + log: dict + If use_log is True, contains loss evolutions by batches and epoches. + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + Cs0, Ys0 = Cs, Ys + nx = get_backend(*Cs0, *Ys0) + Cs = [nx.to_numpy(C) for C in Cs0] + Ys = [nx.to_numpy(Y) for Y in Ys0] + + d = Ys[0].shape[-1] + dataset_size = len(Cs) + + if ps is None: + ps = [unif(C.shape[0]) for C in Cs] + else: + ps = [nx.to_numpy(p) for p in ps] + if q is None: + q = unif(nt) + else: + q = nx.to_numpy(q) + + if Cdict_init is None: + # Initialize randomly structures of dictionary atoms based on samples + dataset_means = [C.mean() for C in Cs] + Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt)) + else: + Cdict = nx.to_numpy(Cdict_init).copy() + assert Cdict.shape == (D, nt, nt) + if Ydict_init is None: + # Initialize randomly features of dictionary atoms based on samples distribution by feature component + dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys]) + Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d)) + else: + Ydict = nx.to_numpy(Ydict_init).copy() + assert Ydict.shape == (D, nt, d) + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_adam_optimizer: + adam_moments_C = _initialize_adam_optimizer(Cdict) + adam_moments_Y = _initialize_adam_optimizer(Ydict) + + log = {'loss_batches': [], 'loss_epochs': []} + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + loss_best_state = np.inf + if batch_size > dataset_size: + batch_size = dataset_size + iter_by_epoch = dataset_size // batch_size + int((dataset_size % batch_size) > 0) + + for epoch in range(epochs): + cumulated_loss_over_epoch = 0. + + for _ in range(iter_by_epoch): + + # Batch iterations + batch = np.random.choice(range(dataset_size), size=batch_size, replace=False) + cumulated_loss_over_batch = 0. + unmixings = np.zeros((batch_size, D)) + Cs_embedded = np.zeros((batch_size, nt, nt)) + Ys_embedded = np.zeros((batch_size, nt, d)) + Ts = [None] * batch_size + + for batch_idx, C_idx in enumerate(batch): + # BCD solver for Gromov-Wassersteisn linear unmixing used independently on each structure of the sampled batch + unmixings[batch_idx], Cs_embedded[batch_idx], Ys_embedded[batch_idx], Ts[batch_idx], current_loss = fused_gromov_wasserstein_linear_unmixing( + Cs[C_idx], Ys[C_idx], Cdict, Ydict, alpha, reg=reg, p=ps[C_idx], q=q, + tol_outer=tol_outer, tol_inner=tol_inner, max_iter_outer=max_iter_outer, max_iter_inner=max_iter_inner + ) + cumulated_loss_over_batch += current_loss + cumulated_loss_over_epoch += cumulated_loss_over_batch + if use_log: + log['loss_batches'].append(cumulated_loss_over_batch) + + # Stochastic projected gradient step over dictionary atoms + grad_Cdict = np.zeros_like(Cdict) + grad_Ydict = np.zeros_like(Ydict) + + for batch_idx, C_idx in enumerate(batch): + shared_term_structures = Cs_embedded[batch_idx] * const_q - (Cs[C_idx].dot(Ts[batch_idx])).T.dot(Ts[batch_idx]) + shared_term_features = diag_q.dot(Ys_embedded[batch_idx]) - Ts[batch_idx].T.dot(Ys[C_idx]) + grad_Cdict += alpha * unmixings[batch_idx][:, None, None] * shared_term_structures[None, :, :] + grad_Ydict += (1 - alpha) * unmixings[batch_idx][:, None, None] * shared_term_features[None, :, :] + grad_Cdict *= 2 / batch_size + grad_Ydict *= 2 / batch_size + + if use_adam_optimizer: + Cdict, adam_moments_C = _adam_stochastic_updates(Cdict, grad_Cdict, learning_rate_C, adam_moments_C) + Ydict, adam_moments_Y = _adam_stochastic_updates(Ydict, grad_Ydict, learning_rate_Y, adam_moments_Y) + else: + Cdict -= learning_rate_C * grad_Cdict + Ydict -= learning_rate_Y * grad_Ydict + + if 'symmetric' in projection: + Cdict = 0.5 * (Cdict + Cdict.transpose((0, 2, 1))) + if 'nonnegative' in projection: + Cdict[Cdict < 0.] = 0. + + if use_log: + log['loss_epochs'].append(cumulated_loss_over_epoch) + if loss_best_state > cumulated_loss_over_epoch: + loss_best_state = cumulated_loss_over_epoch + Cdict_best_state = Cdict.copy() + Ydict_best_state = Ydict.copy() + if verbose: + print('--- epoch: ', epoch, ' cumulated reconstruction error: ', cumulated_loss_over_epoch) + + return nx.from_numpy(Cdict_best_state), nx.from_numpy(Ydict_best_state), log + + +def fused_gromov_wasserstein_linear_unmixing(C, Y, Cdict, Ydict, alpha, reg=0., p=None, q=None, tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, **kwargs): + r""" + Returns the Fused Gromov-Wasserstein linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the attributed dictionary atoms :math:`\{ (\mathbf{C_{dict}[d]},\mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` + + .. math:: + \min_{\mathbf{w}} FGW_{2,\alpha}(\mathbf{C},\mathbf{Y}, \sum_{d=1}^D w_d\mathbf{C_{dict}[d]},\sum_{d=1}^D w_d\mathbf{Y_{dict}[d]}, \mathbf{p}, \mathbf{q}) - reg \| \mathbf{w} \|_2^2 + + such that, :math:`\forall s \leq S` : + + - :math:`\mathbf{w_s}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w_s} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Block Coordinate Descent as discussed in [38], algorithm 6. + + Parameters + ---------- + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : D array-like, shape (D,nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Ydict : D array-like, shape (D,nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. The default is 0. + p : array-like, shape (ns,), optional + Distribution in the source space C. Default is None and corresponds to uniform distribution. + q : array-like, shape (nt,), optional + Distribution in the space depicted by the dictionary. Default is None and corresponds to uniform distribution. + tol_outer : float, optional + Solver precision for the BCD algorithm. + tol_inner : float, optional + Solver precision for the Conjugate Gradient algorithm used to get optimal w at a fixed transport. Default is :math:`10^{-5}`. + max_iter_outer : int, optional + Maximum number of iterations for the BCD. Default is 20. + max_iter_inner : int, optional + Maximum number of iterations for the Conjugate Gradient. Default is 200. + + Returns + ------- + w: array-like, shape (D,) + fused gromov-wasserstein linear unmixing of (C,Y,p) onto the span of the dictionary. + Cembedded: array-like, shape (nt,nt) + embedded structure of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{C_{dict}[d]}`. + Yembedded: array-like, shape (nt,d) + embedded features of :math:`(\mathbf{C},\mathbf{Y}, \mathbf{p})` onto the dictionary, :math:`\sum_d w_d\mathbf{Y_{dict}[d]}`. + T: array-like (ns,nt) + Fused Gromov-Wasserstein transport plan between :math:`(\mathbf{C},\mathbf{p})` and :math:`(\sum_d w_d\mathbf{C_{dict}[d]}, \sum_d w_d\mathbf{Y_{dict}[d]},\mathbf{q})`. + current_loss: float + reconstruction error + References + ------- + + ..[38] Cédric Vincent-Cuaz, Titouan Vayer, Rémi Flamary, Marco Corneli, Nicolas Courty. + "Online Graph Dictionary Learning" + International Conference on Machine Learning (ICML). 2021. + """ + C0, Y0, Cdict0, Ydict0 = C, Y, Cdict, Ydict + nx = get_backend(C0, Y0, Cdict0, Ydict0) + C = nx.to_numpy(C0) + Y = nx.to_numpy(Y0) + Cdict = nx.to_numpy(Cdict0) + Ydict = nx.to_numpy(Ydict0) + + if p is None: + p = unif(C.shape[0]) + else: + p = nx.to_numpy(p) + if q is None: + q = unif(Cdict.shape[-1]) + else: + q = nx.to_numpy(q) + + T = p[:, None] * q[None, :] + D = len(Cdict) + d = Y.shape[-1] + w = unif(D) # Initialize with uniform weights + ns = C.shape[-1] + nt = Cdict.shape[-1] + + # modeling (C,Y) + Cembedded = np.sum(w[:, None, None] * Cdict, axis=0) + Yembedded = np.sum(w[:, None, None] * Ydict, axis=0) + + # constants depending on q + const_q = q[:, None] * q[None, :] + diag_q = np.diag(q) + # Trackers for BCD convergence + convergence_criterion = np.inf + current_loss = 10**15 + outer_count = 0 + Ys_constM = (Y**2).dot(np.ones((d, nt))) # constant in computing euclidean pairwise feature matrix + + while (convergence_criterion > tol_outer) and (outer_count < max_iter_outer): + previous_loss = current_loss + + # 1. Solve GW transport between (C,p) and (\sum_d Cdictionary[d],q) fixing the unmixing w + Yt_varM = (np.ones((ns, d))).dot((Yembedded**2).T) + M = Ys_constM + Yt_varM - 2 * Y.dot(Yembedded.T) # euclidean distance matrix between features + T, log = fused_gromov_wasserstein(M, C, Cembedded, p, q, loss_fun='square_loss', alpha=alpha, armijo=False, G0=T, log=True) + current_loss = log['fgw_dist'] + if reg != 0: + current_loss -= reg * np.sum(w**2) + + # 2. Solve linear unmixing problem over w with a fixed transport plan T + w, Cembedded, Yembedded, current_loss = _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, + T, p, q, const_q, diag_q, current_loss, alpha, reg, + tol=tol_inner, max_iter=max_iter_inner, **kwargs) + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + outer_count += 1 + + return nx.from_numpy(w), nx.from_numpy(Cembedded), nx.from_numpy(Yembedded), nx.from_numpy(T), nx.from_numpy(current_loss) + + +def _cg_fused_gromov_wasserstein_unmixing(C, Y, Cdict, Ydict, Cembedded, Yembedded, w, T, p, q, const_q, diag_q, starting_loss, alpha, reg, tol=10**(-6), max_iter=200, **kwargs): + r""" + Returns for a fixed admissible transport plan, + the optimal linear unmixing :math:`\mathbf{w}` minimizing the Fused Gromov-Wasserstein cost between :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` and :math:`(\sum_d w_d \mathbf{C_{dict}[d]},\sum_d w_d*\mathbf{Y_{dict}[d]}, \mathbf{q})` + + .. math:: + \min_{\mathbf{w}} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D w_d C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\+ (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d w_d \mathbf{Y_{dict}[d]_j} \|_2^2 T_{ij}- reg \| \mathbf{w} \|_2^2 + + Such that : + + - :math:`\mathbf{w}^\top \mathbf{1}_D = 1` + - :math:`\mathbf{w} \geq \mathbf{0}_D` + + Where : + + - :math:`\mathbf{C}` is a (ns,ns) pairwise similarity matrix of variable size ns. + - :math:`\mathbf{Y}` is a (ns,d) features matrix of variable size ns and fixed dimension d. + - :math:`\mathbf{C_{dict}}` is a (D, nt, nt) tensor of D pairwise similarity matrix of fixed size nt. + - :math:`\mathbf{Y_{dict}}` is a (D, nt, d) tensor of D features matrix of fixed size nt and fixed dimension d. + - :math:`\mathbf{p}` is the source distribution corresponding to :math:`\mathbf{C_s}` + - :math:`\mathbf{q}` is the target distribution assigned to every structures in the embedding space. + - :math:`\mathbf{T}` is the optimal transport plan conditioned by the previous state of :math:`\mathbf{w}` + - :math:`\alpha` is the trade-off parameter of Fused Gromov-Wasserstein + - reg is the regularization coefficient. + + The algorithm used for solving the problem is a Conditional Gradient Descent as discussed in [38], algorithm 7. + + Parameters + ---------- + + C : array-like, shape (ns, ns) + Metric/Graph cost matrix. + Y : array-like, shape (ns, d) + Feature matrix. + Cdict : list of D array-like, shape (nt,nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt,d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt,nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt,d) + Embedded features of (C,Y) onto the dictionary + w: array-like, shape (n_D,) + Linear unmixing of (C,Y) onto (Cdict,Ydict) + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{qq}^\top` where :math:`\mathbf{q}` is the target space distribution. + diag_q: array-like, shape (nt,nt) + diagonal matrix with values of q on the diagonal. + T: array-like, shape (ns,nt) + fixed transport plan between (C,Y) and its model + p : array-like, shape (ns,) + Distribution in the source space (C,Y). + q : array-like, shape (nt,) + Distribution in the embedding space depicted by the dictionary. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + w: ndarray (D,) + linear unmixing of :math:`(\mathbf{C},\mathbf{Y},\mathbf{p})` onto the span of :math:`(C_{dict},Y_{dict})` given OT corresponding to previous unmixing. + """ + convergence_criterion = np.inf + current_loss = starting_loss + count = 0 + const_TCT = np.transpose(C.dot(T)).dot(T) + ones_ns_d = np.ones(Y.shape) + + while (convergence_criterion > tol) and (count < max_iter): + previous_loss = current_loss + + # 1) Compute gradient at current point w + # structure + grad_w = alpha * np.sum(Cdict * (Cembedded[None, :, :] * const_q[None, :, :] - const_TCT[None, :, :]), axis=(1, 2)) + # feature + grad_w += (1 - alpha) * np.sum(Ydict * (diag_q.dot(Yembedded)[None, :, :] - T.T.dot(Y)[None, :, :]), axis=(1, 2)) + grad_w -= reg * w + grad_w *= 2 + + # 2) Conditional gradient direction finding: x= \argmin_x x^T.grad_w + min_ = np.min(grad_w) + x = (grad_w == min_).astype(np.float64) + x /= np.sum(x) + + # 3) Line-search step: solve \argmin_{\gamma \in [0,1]} a*gamma^2 + b*gamma + c + gamma, a, b, Cembedded_diff, Yembedded_diff = _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg) + + # 4) Updates: w <-- (1-gamma)*w + gamma*x + w += gamma * (x - w) + Cembedded += gamma * Cembedded_diff + Yembedded += gamma * Yembedded_diff + current_loss += a * (gamma**2) + b * gamma + + if previous_loss != 0: + convergence_criterion = abs(previous_loss - current_loss) / abs(previous_loss) + else: + convergence_criterion = abs(previous_loss - current_loss) / 10**(-12) + count += 1 + + return w, Cembedded, Yembedded, current_loss + + +def _linesearch_fused_gromov_wasserstein_unmixing(w, grad_w, x, Y, Cdict, Ydict, Cembedded, Yembedded, T, const_q, const_TCT, ones_ns_d, alpha, reg, **kwargs): + r""" + Compute optimal steps for the line search problem of Fused Gromov-Wasserstein linear unmixing + .. math:: + \min_{\gamma \in [0,1]} \alpha \sum_{ijkl} (C_{i,j} - \sum_{d=1}^D z_d(\gamma)C_{dict}[d]_{k,l} )^2 T_{i,k}T_{j,l} \\ + (1-\alpha) \sum_{ij} \| \mathbf{Y_i} - \sum_d z_d(\gamma) \mathbf{Y_{dict}[d]_j} \|_2^2 - reg\| \mathbf{z}(\gamma) \|_2^2 + + + Such that : + + - :math:`\mathbf{z}(\gamma) = (1- \gamma)\mathbf{w} + \gamma \mathbf{x}` + + Parameters + ---------- + + w : array-like, shape (D,) + Unmixing. + grad_w : array-like, shape (D, D) + Gradient of the reconstruction loss with respect to w. + x: array-like, shape (D,) + Conditional gradient direction. + Y: arrat-like, shape (ns,d) + Feature matrix of the input space + Cdict : list of D array-like, shape (nt, nt) + Metric/Graph cost matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,nt). + Ydict : list of D array-like, shape (nt, d) + Feature matrices composing the dictionary on which to embed (C,Y). + Each matrix in the dictionary must have the same size (nt,d). + Cembedded: array-like, shape (nt, nt) + Embedded structure of (C,Y) onto the dictionary + Yembedded: array-like, shape (nt, d) + Embedded features of (C,Y) onto the dictionary + T: array-like, shape (ns, nt) + Fixed transport plan between (C,Y) and its current model. + const_q: array-like, shape (nt,nt) + product matrix :math:`\mathbf{q}\mathbf{q}^\top` where q is the target space distribution. Used to avoid redundant computations. + const_TCT: array-like, shape (nt, nt) + :math:`\mathbf{T}^\top \mathbf{C}^\top \mathbf{T}`. Used to avoid redundant computations. + ones_ns_d: array-like, shape (ns, d) + :math:`\mathbf{1}_{ ns \times d}`. Used to avoid redundant computations. + alpha: float, + Trade-off parameter of Fused Gromov-Wasserstein. + reg : float, optional + Coefficient of the negative quadratic regularization used to promote sparsity of w. + + Returns + ------- + gamma: float + Optimal value for the line-search step + a: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + b: float + Constant factor appearing in the factorization :math:`a \gamma^2 + b \gamma +c` of the reconstruction loss + Cembedded_diff: numpy array, shape (nt, nt) + Difference between structure matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + Yembedded_diff: numpy array, shape (nt, nt) + Difference between feature matrix of models evaluated in :math:`\mathbf{w}` and in :math:`\mathbf{w}`. + """ + # polynomial coefficients from quadratic objective (with respect to w) on structures + Cembedded_x = np.sum(x[:, None, None] * Cdict, axis=0) + Cembedded_diff = Cembedded_x - Cembedded + trace_diffx = np.sum(Cembedded_diff * Cembedded_x * const_q) + trace_diffw = np.sum(Cembedded_diff * Cembedded * const_q) + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_gw = trace_diffx - trace_diffw + b_gw = 2 * (trace_diffw - np.sum(Cembedded_diff * const_TCT)) + + # polynomial coefficient from quadratic objective (with respect to w) on features + Yembedded_x = np.sum(x[:, None, None] * Ydict, axis=0) + Yembedded_diff = Yembedded_x - Yembedded + # Constant factor appearing in the factorization a*gamma^2 + b*g + c of the Gromov-Wasserstein reconstruction loss + a_w = np.sum(ones_ns_d.dot((Yembedded_diff**2).T) * T) + b_w = 2 * np.sum(T * (ones_ns_d.dot((Yembedded * Yembedded_diff).T) - Y.dot(Yembedded_diff.T))) + + a = alpha * a_gw + (1 - alpha) * a_w + b = alpha * b_gw + (1 - alpha) * b_w + if reg != 0: + a -= reg * np.sum((x - w)**2) + b -= 2 * reg * np.sum(w * (x - w)) + if a > 0: + gamma = min(1, max(0, -b / (2 * a))) + elif a + b < 0: + gamma = 1 + else: + gamma = 0 + + return gamma, a, b, Cembedded_diff, Yembedded_diff diff --git a/test/test_gromov.py b/test/test_gromov.py index 4b995d5..329f99c 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -3,6 +3,7 @@ # Author: Erwan Vautier # Nicolas Courty # Titouan Vayer +# Cédric Vincent-Cuaz # # License: MIT License @@ -26,6 +27,7 @@ def test_gromov(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -37,9 +39,10 @@ def test_gromov(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)) + G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) @@ -56,9 +59,9 @@ def test_gromov(nx): gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True) gwb = nx.to_numpy(gwb) - gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False) + gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', G0=G0, log=False) gw_valb = nx.to_numpy( - ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) ) G = log['T'] @@ -91,6 +94,7 @@ def test_gromov_dtype_device(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -105,9 +109,10 @@ def test_gromov_dtype_device(nx): C2b = nx.from_numpy(C2, type_as=tp) pb = nx.from_numpy(p, type_as=tp) qb = nx.from_numpy(q, type_as=tp) + G0b = nx.from_numpy(G0, type_as=tp) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -123,6 +128,7 @@ def test_gromov_device_tf(): xt = xs[::-1].copy() p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) C1 /= C1.max() @@ -134,8 +140,9 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) - Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) - gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) + G0b = nx.from_numpy(G0) + Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) + gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) nx.assert_same_dtype_device(C1b, gw_valb) @@ -145,6 +152,7 @@ def test_gromov_device_tf(): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0b) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -554,6 +562,7 @@ def test_fgw(nx): p = ot.unif(n_samples) q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] C1 = ot.dist(xs, xs) C2 = ot.dist(xt, xt) @@ -569,9 +578,10 @@ def test_fgw(nx): C2b = nx.from_numpy(C2) pb = nx.from_numpy(p) qb = nx.from_numpy(q) + G0b = nx.from_numpy(G0) - G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) Gb = nx.to_numpy(Gb) # check constraints @@ -586,8 +596,8 @@ def test_fgw(nx): np.testing.assert_allclose( Gb, np.flipud(Id), atol=1e-04) # cf convergence gromov - fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True) + fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', G0=None, alpha=0.5, log=True) + fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', G0=G0b, alpha=0.5, log=True) fgwb = nx.to_numpy(fgwb) G = log['T'] @@ -698,3 +708,523 @@ def test_fgw_barycenter(nx): Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) + + +def test_gromov_wasserstein_linear_unmixing(nx): + n = 10 + + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Cdictb = nx.from_numpy(Cdict) + pb = nx.from_numpy(p) + tol = 10**(-5) + # Tests without regularization + reg = 0. + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + + reg = 0.001 + unmixing1, C1_emb, OT, reconstruction1 = ot.gromov.gromov_wasserstein_linear_unmixing( + C1, Cdict, reg=reg, p=p, q=p, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, OTb, reconstruction1b = ot.gromov.gromov_wasserstein_linear_unmixing( + C1b, Cdictb, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, OT, reconstruction2 = ot.gromov.gromov_wasserstein_linear_unmixing( + C2, Cdict, reg=reg, p=None, q=None, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, OTb, reconstruction2b = ot.gromov.gromov_wasserstein_linear_unmixing( + C2b, Cdictb, reg=reg, p=pb, q=pb, + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + Csb = [nx.from_numpy(C) for C in Cs] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + + # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization + # > Compute initial reconstruction of samples on this random dictionary without backend + use_adam_optimizer = True + verbose = False + tol = 10**(-5) + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_init, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn the dictionary using this init + Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, + epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary without backend + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # knowing than the initialization scheme is the same than implemented to set the benchmarked initialization. + np.random.seed(0) + Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis, p=None, q=None, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) + + # Test: Perform same comparison without providing the initial dictionary being an optional input + # and testing other optimization settings untested until now. + # We pass previously estimated dictionaries to speed up the process. + use_adam_optimizer = False + verbose = True + use_log = True + + np.random.seed(0) + Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # Test: Same after going through backend + np.random.seed(0) + Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( + Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, + epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( + Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) + + +def test_fused_gromov_wasserstein_linear_unmixing(nx): + + n = 10 + X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cdict = np.stack([C1, C2]) + Ydict = np.stack([F, F]) + p = ot.unif(n) + + C1b = nx.from_numpy(C1) + C2b = nx.from_numpy(C2) + Fb = nx.from_numpy(F) + Cdictb = nx.from_numpy(Cdict) + Ydictb = nx.from_numpy(Ydict) + pb = nx.from_numpy(p) + # Tests without regularization + reg = 0. + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + # Tests with regularization + reg = 0.001 + + unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + ) + + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) + np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) + np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) + np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) + np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) + np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(C1b_emb.shape, (n, n)) + np.testing.assert_allclose(C2b_emb.shape, (n, n)) + + +def test_fused_gromov_wasserstein_dictionary_learning(nx): + + # create dataset composed from 2 structures which are repeated 5 times + shape = 10 + n_samples = 2 + n_atoms = 2 + projection = 'nonnegative_symmetric' + X1, y1 = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + X2, y2 = ot.datasets.make_data_classif('3gauss2', shape, random_state=42) + F, y = ot.datasets.make_data_classif('3gauss', shape, random_state=42) + + C1 = ot.dist(X1) + C2 = ot.dist(X2) + Cs = [C1.copy() for _ in range(n_samples // 2)] + [C2.copy() for _ in range(n_samples // 2)] + Ys = [F.copy() for _ in range(n_samples)] + ps = [ot.unif(shape) for _ in range(n_samples)] + q = ot.unif(shape) + + # Provide initialization for the graph dictionary of shape (n_atoms, shape, shape) + # following the same procedure than implemented in gromov_wasserstein_dictionary_learning. + dataset_structure_means = [C.mean() for C in Cs] + np.random.seed(0) + Cdict_init = np.random.normal(loc=np.mean(dataset_structure_means), scale=np.std(dataset_structure_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': + Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) + Cdict_init[Cdict_init < 0.] = 0. + dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) + Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) + + Csb = [nx.from_numpy(C) for C in Cs] + Ysb = [nx.from_numpy(Y) for Y in Ys] + psb = [nx.from_numpy(p) for p in ps] + qb = nx.from_numpy(q) + Cdict_initb = nx.from_numpy(Cdict_init) + Ydict_initb = nx.from_numpy(Ydict_init) + + # Test: Compute initial reconstruction of samples on this random dictionary + alpha = 0.5 + use_adam_optimizer = True + verbose = False + tol = 1e-05 + epochs = 1 + + initial_total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + initial_total_reconstruction += reconstruction + + # > Learn a dictionary using this given initialization and check that the reconstruction loss + # on the learned dictionary is lower than the one using its initialization. + Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction += reconstruction + # Compare both + np.testing.assert_array_less(total_reconstruction, initial_total_reconstruction) + + # Test: Perform same experiments after going through backend + + Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, + epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b += reconstruction + + np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) + np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) + np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) + np.testing.assert_allclose(Ydict, nx.to_numpy(Ydictb), atol=1e-03) + + # Test: Perform similar experiment without providing the initial dictionary being an optional input + np.random.seed(0) + Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis += reconstruction + + np.testing.assert_allclose(total_reconstruction_bis, total_reconstruction, atol=1e-05) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, + epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis += reconstruction + np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) + + # Test: without using adam optimizer, with log and verbose set to True + use_adam_optimizer = False + verbose = True + use_log = True + + # > Experiment providing previously estimated dictionary to speed up the test compared to providing initial random init. + np.random.seed(0) + Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_bis2 += reconstruction + + np.testing.assert_array_less(total_reconstruction_bis2, total_reconstruction) + + # > Same after going through backend + np.random.seed(0) + Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( + Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, + epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose + ) + + # > Compute reconstruction of samples on learned dictionary + total_reconstruction_b_bis2 = 0 + for i in range(n_samples): + _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( + Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., + tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + ) + total_reconstruction_b_bis2 += reconstruction + + # > Compare results with/without backend + np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) -- cgit v1.2.3 From 3302cd48cdcc5d4832997bae921952cc3917fb59 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 23 Feb 2022 09:53:13 +0100 Subject: [MRG] Build POT against oldest-supported-numpy (local PR) (#349) * Configure setup to compile against oldest supported numpy version using the meta-package: https://pypi.org/project/oldest-supported-numpy/ - * Set minimum Python requirement to `>=3.7` in setup.py since !328 removed Python 3.6 support * Fix typo in pyproject.toml - * Update setup.py * Update setup.py and * build wheels * remove install dependencies for wheels building and build wheels * Apply suggestions from code review Co-authored-by: David M. Ghiurco <9147386+davidghiurco@users.noreply.github.com> * correct timing test add info in release file and build wheels * pep8 and Co-authored-by: David Ghiurco <9147386+davidghiurco@users.noreply.github.com> --- .github/workflows/build_wheels.yml | 8 +------- .github/workflows/build_wheels_weekly.yml | 2 -- RELEASES.md | 1 + docs/source/quickstart.rst | 2 +- pyproject.toml | 2 +- requirements.txt | 1 - setup.py | 5 +++-- test/test_utils.py | 4 +++- 8 files changed, 10 insertions(+), 15 deletions(-) diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index c746eb8..475058c 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -27,8 +27,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -U "cython" - name: Install cibuildwheel run: | @@ -37,7 +35,6 @@ jobs: - name: Build wheels env: CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp36*" # remove pypy on mac and win (wrong version) - CIBW_BEFORE_BUILD: "pip install numpy cython" run: | python -m cibuildwheel --output-dir wheelhouse @@ -65,8 +62,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -U "cython" - name: Install cibuildwheel run: | @@ -80,8 +75,7 @@ jobs: - name: Build wheels env: - CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl* cp36*" # remove pypy on mac and win (wrong version) - CIBW_BEFORE_BUILD: "pip install numpy cython" + CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version) CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU CIBW_ARCHS_MACOS: x86_64 universal2 arm64 run: | diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index dbf342f..b9154c5 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -26,8 +26,6 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install -U "cython" - name: Install cibuildwheel run: | diff --git a/RELEASES.md b/RELEASES.md index 925920a..92b7ba5 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -16,6 +16,7 @@ - Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR #338) +- Make POT ABI compatible with old and new numpy (Issue #346, PR #349) ## 0.8.1.0 *December 2021* diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index e74b019..09a362b 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -1002,7 +1002,7 @@ FAQ 2. **pip install POT fails with error : ImportError: No module named Cython.Build** - As discussed shortly in the README file. POT requires to have :code:`numpy` + As discussed shortly in the README file. POT<0.8 requires to have :code:`numpy` and :code:`cython` installed to build. This corner case is not yet handled by :code:`pip` and for now you need to install both library prior to installing POT. diff --git a/pyproject.toml b/pyproject.toml index 93ebab3..3789206 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,3 @@ [build-system] -requires = ["setuptools", "wheel", "numpy>=1.20", "cython>=0.23"] +requires = ["setuptools", "wheel", "oldest-supported-numpy", "cython>=0.23"] build-backend = "setuptools.build_meta" \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f9934ce..7cbb29a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,5 @@ numpy>=1.20 scipy>=1.3 -cython matplotlib autograd pymanopt==0.2.4; python_version <'3' diff --git a/setup.py b/setup.py index d46ae1c..c03191a 100644 --- a/setup.py +++ b/setup.py @@ -68,8 +68,9 @@ setup( license='MIT', scripts=[], data_files=[], - setup_requires=["numpy>=1.20", "cython>=0.23"], - install_requires=["numpy>=1.20", "scipy>=1.0"], + setup_requires=["oldest-supported-numpy", "cython>=0.23"], + install_requires=["numpy>=1.16", "scipy>=1.0"], + python_requires=">=3.6", classifiers=[ 'Development Status :: 5 - Production/Stable', 'Intended Audience :: Developers', diff --git a/test/test_utils.py b/test/test_utils.py index 5ad167b..3cfd295 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -67,7 +67,9 @@ def test_tic_toc(): t2 = ot.toq() # test timing - np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1) + # np.testing.assert_allclose(0.1, t, rtol=1e-1, atol=1e-1) + # very slow macos github action equality not possible + assert t > 0.09 # test toc vs toq np.testing.assert_allclose(t, t2, rtol=1e-1, atol=1e-1) -- cgit v1.2.3 From 17814726200b4010afbf52701e8bcb132d678502 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 23 Feb 2022 15:16:48 +0100 Subject: [MRG] Proper links in release file in documentation (#350) * propreer links in release file in documentation * add pr in release file --- .circleci/config.yml | 5 +++++ RELEASES.md | 15 ++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 5427979..39c19fb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -57,6 +57,11 @@ jobs: which python python -c "import ot" + - run: + name: Correct link in release file + command: | + sed -i -r 's/PR #([[:digit:]]*)/\[PR #\1\]\(https:\/\/github.com\/PythonOT\/POT\/pull\/\1\)/' RELEASES.md + sed -i -r 's/Issue #([[:digit:]]*)/\[Issue #\1\]\(https:\/\/github.com\/PythonOT\/POT\/issues\/\1\)/' RELEASES.md # Build docs - run: name: make html diff --git a/RELEASES.md b/RELEASES.md index 92b7ba5..c1068f3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,12 +11,13 @@ - Backend implementation for `ot.lp.free_support_barycenter` (PR #340). - Add weak OT solver + example (PR #341). - Add (F)GW linear dictionary learning solvers + example (PR #319) +- Add links to related PR and Issues in the doc release page (PR #350) #### Closed issues -- Bug in instantiating an `autograd` function (`ValFunction`, Issue #337, PR - #338) -- Make POT ABI compatible with old and new numpy (Issue #346, PR #349) +- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, + PR #338) +- Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) ## 0.8.1.0 *December 2021* @@ -60,10 +61,10 @@ As always we want to that the contributors who helped make POT better (and bug f - Fix bug in older Numpy ABI (<1.20) (Issue #308, PR #326) - Fix bug in `ot.dist` function when non euclidean distance (Issue #305, PR #306) -- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309, PR - #310) -- Fix bug in generalized Conditional gradient solver and SinkhornL1L2 (Issue - #311, PR #313) +- Fix gradient scaling for functions using `nx.set_gradients` (Issue #309, + PR #310) +- Fix bug in generalized Conditional gradient solver and SinkhornL1L2 + (Issue #311, PR #313) - Fix log error in `gromov_barycenters` (Issue #317, PR #3018) ## 0.8.0 -- cgit v1.2.3 From 9412f0ad1c0003e659b7d779bf8b6728e0e5e60f Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Wed, 2 Mar 2022 11:35:47 +0100 Subject: [MRG] Gromov_Wasserstein2 not performing backward properly on GPU (#352) * Resolves gromov wasserstein backward bug * release file updated --- RELEASES.md | 3 +++ ot/gromov.py | 12 +++++++---- test/test_gromov.py | 60 +++++++++++++++++++++++++++++++---------------------- 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index c1068f3..18562e7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -18,6 +18,9 @@ - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) +- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA + tensors (Issue #351, PR #352) + ## 0.8.1.0 *December 2021* diff --git a/ot/gromov.py b/ot/gromov.py index f5a1f91..c5a82d1 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -546,8 +546,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gw = log_gw['gw_dist'] if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) gw = nx.set_gradients(gw, (p0, q0, C10, C20), (log_gw['u'], log_gw['v'], gC1, gC2)) @@ -786,8 +788,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 log_fgw['T'] = T0 if loss_fun == 'square_loss': - gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) - gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T) + gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T) + gC1 = nx.from_numpy(gC1, type_as=C10) + gC2 = nx.from_numpy(gC2, type_as=C10) fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) diff --git a/test/test_gromov.py b/test/test_gromov.py index 329f99c..0dcf2da 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,19 +181,24 @@ def test_gromov2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) - val = ot.gromov_wasserstein2(C11, C12, p1, q1) + val = ot.gromov_wasserstein2(C11, C12, p1, q1) - val.backward() + val.backward() - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape @pytest.skip_backend("jax", reason="test very slow with jax backend") @@ -636,21 +641,26 @@ def test_fgw2_gradients(): if torch: - p1 = torch.tensor(p, requires_grad=True) - q1 = torch.tensor(q, requires_grad=True) - C11 = torch.tensor(C1, requires_grad=True) - C12 = torch.tensor(C2, requires_grad=True) - M1 = torch.tensor(M, requires_grad=True) - - val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) - - val.backward() - - assert q1.shape == q1.grad.shape - assert p1.shape == p1.grad.shape - assert C11.shape == C11.grad.shape - assert C12.shape == C12.grad.shape - assert M1.shape == M1.grad.shape + devices = [torch.device("cpu")] + if torch.cuda.is_available(): + devices.append(torch.device("cuda")) + for device in devices: + p1 = torch.tensor(p, requires_grad=True, device=device) + q1 = torch.tensor(q, requires_grad=True, device=device) + C11 = torch.tensor(C1, requires_grad=True, device=device) + C12 = torch.tensor(C2, requires_grad=True, device=device) + M1 = torch.tensor(M, requires_grad=True, device=device) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert val.device == p1.device + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape def test_fgw_barycenter(nx): -- cgit v1.2.3 From b47d5045d97ba0c0c98dada34ad643a26c1fb86e Mon Sep 17 00:00:00 2001 From: Keiland Date: Wed, 9 Mar 2022 13:09:15 -0800 Subject: [MRG] minor spelling fix (#353) --- examples/plot_OT_L1_vs_L2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index 60353ab..cb94574 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -150,7 +150,7 @@ pl.clf() pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') -pl.title('Source and traget distributions') +pl.title('Source and target distributions') # Cost matrices -- cgit v1.2.3 From 9b9d2221d257f40ea3eb58b279b30d69162d62bb Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 18 Mar 2022 08:00:19 +0100 Subject: [MRG] Add logo to POT (#357) * add logo code and logo to doc * update release file --- RELEASES.md | 1 + docs/source/_static/images/logo.png | Bin 0 -> 5038 bytes docs/source/_static/images/logo.svg | 200 +++++++++++++++++++++++++++++++ docs/source/_static/images/logo_dark.png | Bin 0 -> 3437 bytes docs/source/_static/images/logo_dark.svg | 187 +++++++++++++++++++++++++++++ docs/source/conf.py | 4 +- docs/source/index.rst | 5 + examples/others/plot_logo.py | 112 +++++++++++++++++ 8 files changed, 508 insertions(+), 1 deletion(-) create mode 100644 docs/source/_static/images/logo.png create mode 100644 docs/source/_static/images/logo.svg create mode 100644 docs/source/_static/images/logo_dark.png create mode 100644 docs/source/_static/images/logo_dark.svg create mode 100644 examples/others/plot_logo.py diff --git a/RELEASES.md b/RELEASES.md index 18562e7..0f1f231 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- A brand new logo for POT (PR #357) - Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values of the regularization parameter (PR #336). diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png new file mode 100644 index 0000000..7be5df7 Binary files /dev/null and b/docs/source/_static/images/logo.png differ diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg new file mode 100644 index 0000000..0bf2cb7 --- /dev/null +++ b/docs/source/_static/images/logo.svg @@ -0,0 +1,200 @@ + + + + + + + + + 2022-03-17T17:25:30.736761 + image/svg+xml + + + Matplotlib v3.3.3, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/_static/images/logo_dark.png b/docs/source/_static/images/logo_dark.png new file mode 100644 index 0000000..f484188 Binary files /dev/null and b/docs/source/_static/images/logo_dark.png differ diff --git a/docs/source/_static/images/logo_dark.svg b/docs/source/_static/images/logo_dark.svg new file mode 100644 index 0000000..56ce2d9 --- /dev/null +++ b/docs/source/_static/images/logo_dark.svg @@ -0,0 +1,187 @@ + + + + + + + + + 2022-03-17T17:25:30.847142 + image/svg+xml + + + Matplotlib v3.3.3, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/conf.py b/docs/source/conf.py index d1b8426..60d0bb7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -162,6 +162,7 @@ html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. + html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. @@ -176,7 +177,7 @@ html_theme_options = {} # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = '_static/images/logo_dark.svg' # The name of an image file (relative to this directory) to use as a favicon of # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -188,6 +189,7 @@ html_theme_options = {} # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] + # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. diff --git a/docs/source/index.rst b/docs/source/index.rst index 8de31ae..7ff7d22 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -6,6 +6,10 @@ POT: Python Optimal Transport ============================= +.. image:: _static/images/logo.svg + :width: 400 + :alt: POT Logo + Contents -------- @@ -20,6 +24,7 @@ Contents .github/CONTRIBUTING .github/CODE_OF_CONDUCT + .. include:: ../../README.md :parser: myst_parser.sphinx_ diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py new file mode 100644 index 0000000..afddcad --- /dev/null +++ b/examples/others/plot_logo.py @@ -0,0 +1,112 @@ + +# -*- coding: utf-8 -*- +r""" +======================= +Logo of the POT toolbox +======================= + +In this example we plot the logo of the POT toolbox. + +A specificity of this logo is that it is done 100% in Python and generated using +matplotlib using the EMD solver from POT. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +import numpy as np +import matplotlib.pyplot as pl +import ot + +# %% +# Data for logo +# ------------- + + +# Letter P +p1 = np.array([[0, 6.], [0, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) +p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ]) + +# Letter O +o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ]) +o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ]) + +# scaling and translation for letter O +o1[:, 0] += 6.4 +o2[:, 0] += 6.4 +o1[:, 0] *= 0.6 +o2[:, 0] *= 0.6 + +# letter T +t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) +t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ]) + +# translatin the T +t1[:, 0] += 7.1 +t2[:, 0] += 7.1 + +# Cocatenate all letters +x1 = np.concatenate((p1, o1, t1), axis=0) +x2 = np.concatenate((p2, o2, t2), axis=0) + +# Horizontal and vertical scaling +sx = 1.0 +sy = .5 +x1[:, 0] *= sx +x1[:, 1] *= sy +x2[:, 0] *= sx +x2[:, 1] *= sy + +# %% +# Plot the logo (clear background) +# -------------------------------- + +# Solve OT problem between the points +M = ot.dist(x1, x2, metric='euclidean') +T = ot.emd([], [], M) + +pl.figure(1, (3.5, 1.1)) +pl.clf() +# plot the OT plan +for i in range(M.shape[0]): + for j in range(M.shape[1]): + if T[i, j] > 1e-8: + pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='k', alpha=0.6, linewidth=3, zorder=1) +# plot the samples +pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='C3', markeredgecolor='k') +pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='b', markeredgecolor='k') + + +pl.axis('equal') +pl.axis('off') + +# Save logo file +# pl.savefig('logo.svg', dpi=150, bbox_inches='tight') +# pl.savefig('logo.png', dpi=150, bbox_inches='tight') + +# %% +# Plot the logo (dark background) +# -------------------------------- + +pl.figure(2, (3.5, 1.1), facecolor='darkgray') +pl.clf() +# plot the OT plan +for i in range(M.shape[0]): + for j in range(M.shape[1]): + if T[i, j] > 1e-8: + pl.plot([x1[i, 0], x2[j, 0]], [x1[i, 1], x2[j, 1]], color='w', alpha=0.8, linewidth=3, zorder=1) +# plot the samples +pl.plot(x1[:, 0], x1[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') +pl.plot(x2[:, 0], x2[:, 1], 'o', markerfacecolor='w', markeredgecolor='w') + +pl.axis('equal') +pl.axis('off') + +# Save logo file +# pl.savefig('logo_dark.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo_dark.png', dpi=150, transparent=True, bbox_inches='tight') -- cgit v1.2.3 From 767171593f2a98a26b9a39bf110a45085e3b982e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Thu, 24 Mar 2022 10:53:47 +0100 Subject: [MRG] Domain adaptation and unbalanced solvers with backend support (#343) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: Rémi Flamary --- RELEASES.md | 2 + ot/backend.py | 304 ++++++++++++++++++++++++++++++++++---- ot/bregman.py | 17 +-- ot/da.py | 382 +++++++++++++++++++++++++++--------------------- ot/lp/__init__.py | 83 ++++++++--- ot/optim.py | 11 +- ot/unbalanced.py | 302 +++++++++++++++++++------------------- ot/utils.py | 26 +++- test/test_1d_solver.py | 28 +--- test/test_backend.py | 66 ++++++++- test/test_bregman.py | 81 +++------- test/test_da.py | 307 +++++++++++++++++++------------------- test/test_gromov.py | 147 +++++++------------ test/test_optim.py | 17 +-- test/test_ot.py | 19 +-- test/test_sliced.py | 32 +--- test/test_unbalanced.py | 157 ++++++++++++-------- test/test_weak.py | 4 +- 18 files changed, 1160 insertions(+), 825 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 0f1f231..86b401a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,7 @@ of the regularization parameter (PR #336). - Backend implementation for `ot.lp.free_support_barycenter` (PR #340). - Add weak OT solver + example (PR #341). +- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343). - Add (F)GW linear dictionary learning solvers + example (PR #319) - Add links to related PR and Issues in the doc release page (PR #350) @@ -19,6 +20,7 @@ - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) +- Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343) - Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA tensors (Issue #351, PR #352) diff --git a/ot/backend.py b/ot/backend.py index 6e0bc3d..361ffba 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -87,7 +87,9 @@ Performance # License: MIT License import numpy as np -import scipy.special as scipy +import scipy +import scipy.linalg +import scipy.special as special from scipy.sparse import issparse, coo_matrix, csr_matrix import warnings import time @@ -102,7 +104,7 @@ except ImportError: try: import jax import jax.numpy as jnp - import jax.scipy.special as jscipy + import jax.scipy.special as jspecial from jax.lib import xla_bridge jax_type = jax.numpy.ndarray except ImportError: @@ -202,13 +204,29 @@ class Backend(): def __str__(self): return self.__name__ - # convert to numpy - def to_numpy(self, a): + # convert batch of tensors to numpy + def to_numpy(self, *arrays): + """Returns the numpy version of tensors""" + if len(arrays) == 1: + return self._to_numpy(arrays[0]) + else: + return [self._to_numpy(array) for array in arrays] + + # convert a tensor to numpy + def _to_numpy(self, a): """Returns the numpy version of a tensor""" raise NotImplementedError() - # convert from numpy - def from_numpy(self, a, type_as=None): + # convert batch of arrays from numpy + def from_numpy(self, *arrays, type_as=None): + """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" + if len(arrays) == 1: + return self._from_numpy(arrays[0], type_as=type_as) + else: + return [self._from_numpy(array, type_as=type_as) for array in arrays] + + # convert an array from numpy + def _from_numpy(self, a, type_as=None): """Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)""" raise NotImplementedError() @@ -536,6 +554,16 @@ class Backend(): """ raise NotImplementedError() + def argmin(self, a, axis=None): + r""" + Returns the indices of the minimum values of a tensor along given dimensions. + + This function follows the api from :any:`numpy.argmin` + + See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html + """ + raise NotImplementedError() + def mean(self, a, axis=None): r""" Computes the arithmetic mean of a tensor along given dimensions. @@ -786,6 +814,72 @@ class Backend(): """ raise NotImplementedError() + def solve(self, a, b): + r""" + Solves a linear matrix equation, or system of linear scalar equations. + + This function follows the api from :any:`numpy.linalg.solve`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html + """ + raise NotImplementedError() + + def trace(self, a): + r""" + Returns the sum along diagonals of the array. + + This function follows the api from :any:`numpy.trace`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html + """ + raise NotImplementedError() + + def inv(self, a): + r""" + Computes the inverse of a matrix. + + This function follows the api from :any:`scipy.linalg.inv`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html + """ + raise NotImplementedError() + + def sqrtm(self, a): + r""" + Computes the matrix square root. Requires input to be definite positive. + + This function follows the api from :any:`scipy.linalg.sqrtm`. + + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html + """ + raise NotImplementedError() + + def isfinite(self, a): + r""" + Tests element-wise for finiteness (not infinity and not Not a Number). + + This function follows the api from :any:`numpy.isfinite`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html + """ + raise NotImplementedError() + + def array_equal(self, a, b): + r""" + True if two arrays have the same shape and elements, False otherwise. + + This function follows the api from :any:`numpy.array_equal`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html + """ + raise NotImplementedError() + + def is_floating_point(self, a): + r""" + Returns whether or not the input consists of floats + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -802,10 +896,10 @@ class NumpyBackend(Backend): rng_ = np.random.RandomState() - def to_numpy(self, a): + def _to_numpy(self, a): return a - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): if type_as is None: return a elif isinstance(a, float): @@ -936,6 +1030,9 @@ class NumpyBackend(Backend): def argmax(self, a, axis=None): return np.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return np.argmin(a, axis=axis) + def mean(self, a, axis=None): return np.mean(a, axis=axis) @@ -955,7 +1052,7 @@ class NumpyBackend(Backend): return np.unique(a) def logsumexp(self, a, axis=None): - return scipy.logsumexp(a, axis=axis) + return special.logsumexp(a, axis=axis) def stack(self, arrays, axis=0): return np.stack(arrays, axis) @@ -1004,8 +1101,11 @@ class NumpyBackend(Backend): else: return a - def where(self, condition, x, y): - return np.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return np.where(condition) + else: + return np.where(condition, x, y) def copy(self, a): return a.copy() @@ -1046,6 +1146,27 @@ class NumpyBackend(Backend): results[key] = (t1 - t0) / n_runs return results + def solve(self, a, b): + return np.linalg.solve(a, b) + + def trace(self, a): + return np.trace(a) + + def inv(self, a): + return scipy.linalg.inv(a) + + def sqrtm(self, a): + return scipy.linalg.sqrtm(a) + + def isfinite(self, a): + return np.isfinite(a) + + def array_equal(self, a, b): + return np.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class JaxBackend(Backend): """ @@ -1075,13 +1196,15 @@ class JaxBackend(Backend): jax.device_put(jnp.array(1, dtype=jnp.float64), d) ] - def to_numpy(self, a): + def _to_numpy(self, a): return np.array(a) def _change_device(self, a, type_as): return jax.device_put(a, type_as.device_buffer.device()) - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return jnp.array(a) else: @@ -1216,6 +1339,9 @@ class JaxBackend(Backend): def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return jnp.argmin(a, axis=axis) + def mean(self, a, axis=None): return jnp.mean(a, axis=axis) @@ -1235,7 +1361,7 @@ class JaxBackend(Backend): return jnp.unique(a) def logsumexp(self, a, axis=None): - return jscipy.logsumexp(a, axis=axis) + return jspecial.logsumexp(a, axis=axis) def stack(self, arrays, axis=0): return jnp.stack(arrays, axis) @@ -1293,8 +1419,11 @@ class JaxBackend(Backend): # Currently, JAX does not support sparse matrices return a - def where(self, condition, x, y): - return jnp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return jnp.where(condition) + else: + return jnp.where(condition, x, y) def copy(self, a): # No need to copy, JAX arrays are immutable @@ -1339,6 +1468,28 @@ class JaxBackend(Backend): results[key] = (t1 - t0) / n_runs return results + def solve(self, a, b): + return jnp.linalg.solve(a, b) + + def trace(self, a): + return jnp.trace(a) + + def inv(self, a): + return jnp.linalg.inv(a) + + def sqrtm(self, a): + L, V = jnp.linalg.eigh(a) + return (V * jnp.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return jnp.isfinite(a) + + def array_equal(self, a, b): + return jnp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class TorchBackend(Backend): """ @@ -1384,10 +1535,10 @@ class TorchBackend(Backend): self.ValFunction = ValFunction - def to_numpy(self, a): + def _to_numpy(self, a): return a.cpu().detach().numpy() - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): if isinstance(a, float): a = np.array(a) if type_as is None: @@ -1564,6 +1715,9 @@ class TorchBackend(Backend): def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) + def argmin(self, a, axis=None): + return torch.argmin(a, dim=axis) + def mean(self, a, axis=None): if axis is not None: return torch.mean(a, dim=axis) @@ -1580,8 +1734,11 @@ class TorchBackend(Backend): return torch.linspace(start, stop, num, dtype=torch.float64) def meshgrid(self, a, b): - X, Y = torch.meshgrid(a, b) - return X.T, Y.T + try: + return torch.meshgrid(a, b, indexing="xy") + except TypeError: + X, Y = torch.meshgrid(a, b) + return X.T, Y.T def diag(self, a, k=0): return torch.diag(a, diagonal=k) @@ -1659,8 +1816,11 @@ class TorchBackend(Backend): else: return a - def where(self, condition, x, y): - return torch.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return torch.where(condition) + else: + return torch.where(condition, x, y) def copy(self, a): return torch.clone(a) @@ -1718,6 +1878,28 @@ class TorchBackend(Backend): torch.cuda.empty_cache() return results + def solve(self, a, b): + return torch.linalg.solve(a, b) + + def trace(self, a): + return torch.trace(a) + + def inv(self, a): + return torch.linalg.inv(a) + + def sqrtm(self, a): + L, V = torch.linalg.eigh(a) + return (V * torch.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return torch.isfinite(a) + + def array_equal(self, a, b): + return torch.equal(a, b) + + def is_floating_point(self, a): + return a.dtype.is_floating_point + class CupyBackend(Backend): # pragma: no cover """ @@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover cp.array(1, dtype=cp.float64) ] - def to_numpy(self, a): + def _to_numpy(self, a): return cp.asnumpy(a) - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if type_as is None: return cp.asarray(a) else: @@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return cp.argmin(a, axis=axis) + def mean(self, a, axis=None): return cp.mean(a, axis=axis) @@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover else: return a - def where(self, condition, x, y): - return cp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return cp.where(condition) + else: + return cp.where(condition, x, y) def copy(self, a): return a.copy() @@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover pinned_mempool.free_all_blocks() return results + def solve(self, a, b): + return cp.linalg.solve(a, b) + + def trace(self, a): + return cp.trace(a) + + def inv(self, a): + return cp.linalg.inv(a) + + def sqrtm(self, a): + L, V = cp.linalg.eigh(a) + return (V * self.sqrt(L)[None, :]) @ V.T + + def isfinite(self, a): + return cp.isfinite(a) + + def array_equal(self, a, b): + return cp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.kind == "f" + class TensorflowBackend(Backend): @@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend): "To use TensorflowBackend, you need to activate the tensorflow " "numpy API. You can activate it by running: \n" "from tensorflow.python.ops.numpy_ops import np_config\n" - "np_config.enable_numpy_behavior()" + "np_config.enable_numpy_behavior()", + stacklevel=2 ) - def to_numpy(self, a): + def _to_numpy(self, a): return a.numpy() - def from_numpy(self, a, type_as=None): + def _from_numpy(self, a, type_as=None): + if isinstance(a, float): + a = np.array(a) if not isinstance(a, self.__type__): if type_as is None: return tf.convert_to_tensor(a) @@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend): def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) + def argmin(self, a, axis=None): + return tnp.argmin(a, axis=axis) + def mean(self, a, axis=None): return tnp.mean(a, axis=axis) @@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend): else: return a - def where(self, condition, x, y): - return tnp.where(condition, x, y) + def where(self, condition, x=None, y=None): + if x is None and y is None: + return tnp.where(condition) + else: + return tnp.where(condition, x, y) def copy(self, a): return tf.identity(a) @@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend): results[key] = (t1 - t0) / n_runs return results + + def solve(self, a, b): + return tf.linalg.solve(a, b) + + def trace(self, a): + return tf.linalg.trace(a) + + def inv(self, a): + return tf.linalg.inv(a) + + def sqrtm(self, a): + return tf.linalg.sqrtm(a) + + def isfinite(self, a): + return tnp.isfinite(a) + + def array_equal(self, a, b): + return tnp.array_equal(a, b) + + def is_floating_point(self, a): + return a.dtype.is_floating diff --git a/ot/bregman.py b/ot/bregman.py index fc20175..c06af2f 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, # geometric interpolation delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) K = projR(K, delta) - K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0) - + K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0 err = nx.norm(nx.sum(K0, axis=1) - old) old = new if log: @@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, classes = nx.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) - Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0]) + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) for c in classes: - nbelemperclass = nx.sum(Ys[d] == c) + nbelemperclass = float(nx.sum(Ys[d] == c)) if nbelemperclass != 0: - Dtmp1[int(c), Ys[d] == c] = 1. - Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) - D1.append(Dtmp1) - D2.append(Dtmp2) + Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1. + Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass) + D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0])) + D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0])) # build the cost matrix and the Gibbs kernel Mtmp = dist(Xs[d], Xt, metric=metric) diff --git a/ot/da.py b/ot/da.py index 841f31a..0b9737e 100644 --- a/ot/da.py +++ b/ot/da.py @@ -12,12 +12,12 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np -import scipy.linalg as linalg +from .backend import get_backend from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots -from .utils import check_params, BaseEstimator +from .utils import list_to_array, check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg from .optim import gcg @@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.cg : General regularized OT """ + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + p = 0.5 epsilon = 1e-3 indices_labels = [] - classes = np.unique(labels_a) + classes = nx.unique(labels_a) for c in classes: - idxc, = np.where(labels_a == c) + idxc, = nx.where(labels_a == c) indices_labels.append(idxc) - W = np.zeros(M.shape) - + W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr) # the transport has been computed. Check if classes are really # separated - W = np.ones(M.shape) + W = nx.ones(M.shape, type_as=M) for (i, c) in enumerate(classes): - majs = np.sum(transp[indices_labels[i]], axis=0) + majs = nx.sum(transp[indices_labels[i]], axis=0) majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs @@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.gcg : Generalized conditional gradient for OT problems """ - lstlab = np.unique(labels_a) + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + + lstlab = nx.unique(labels_a) def f(G): res = 0 for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - res += np.linalg.norm(temp) + res += nx.norm(temp) return res def df(G): - W = np.zeros(G.shape) + W = nx.zeros(G.shape, type_as=G) for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - n = np.linalg.norm(temp) + n = nx.norm(temp) if n: W[labels_a == lab, i] = temp / n return W @@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (d, d) ndarray + L : (d, d) array-like Linear mapping matrix ((:math:`d+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1] if bias: - xs1 = np.hstack((xs, np.ones((ns, 1)))) - xstxs = xs1.T.dot(xs1) - Id = np.eye(d + 1) + xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d + 1, type_as=xs) Id[-1] = 0 I0 = Id[:, :-1] @@ -350,8 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, return x[:-1, :] else: xs1 = xs - xstxs = xs1.T.dot(xs1) - Id = np.eye(d) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d, type_as=xs) I0 = Id def sel(x): @@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -368,23 +376,26 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2) + return ( + nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.sum(sel(L - I0) ** 2) + ) def solve_L(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0) + xst = ns * nx.dot(G, xt) + return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = xs1.dot(L) + xsi = nx.dot(xs1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (ns, d) ndarray + L : (ns, d) array-like Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt = xs.shape[0], xt.shape[0] K = kernel(xs, xs, method=kerneltype, sigma=sigma) if bias: - K1 = np.hstack((K, np.ones((ns, 1)))) - Id = np.eye(ns + 1) + K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1) + Id = nx.eye(ns + 1, type_as=xs) Id[-1] = 0 - Kp = np.eye(ns + 1) + Kp = nx.eye(ns + 1, type_as=xs) Kp[:ns, :ns] = K # ls regu @@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', # Kreg=I # RKHS regul - K0 = K1.T.dot(K1) + eta * Kp + K0 = nx.dot(K1.T, K1) + eta * Kp Kreg = Kp else: K1 = K - Id = np.eye(ns) + Id = nx.eye(ns, type_as=xs) # ls regul # K0 = K1.T.dot(K1)+eta*I @@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + return ( + nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.trace(dots(L.T, Kreg, L)) + ) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, xst) + xst = ns * nx.dot(G, xt) + return nx.solve(K0, xst) def solve_L_bias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, K1.T.dot(xst)) + xst = ns * nx.dot(G, xt) + return nx.solve(K0, nx.dot(K1.T, xst)) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = K1.dot(L) + xsi = nx.dot(K1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain reg : float,optional regularization added to the diagonals of covariances (>0) - ws : np.ndarray (ns,1), optional + ws : array-like (ns,1), optional weights for the source samples - wt : np.ndarray (ns,1), optional + wt : array-like (ns,1), optional weights for the target samples bias: boolean, optional estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) @@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Returns ------- - A : (d, d) ndarray + A : (d, d) array-like Linear operator - b : (1, d) ndarray + b : (1, d) array-like bias log : dict log dictionary return only if log==True in parameters @@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) d = xs.shape[1] if bias: - mxs = xs.mean(0, keepdims=True) - mxt = xt.mean(0, keepdims=True) + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] xs = xs - mxs xt = xt - mxt else: - mxs = np.zeros((1, d)) - mxt = np.zeros((1, d)) + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) if ws is None: - ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] if wt is None: - wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) - Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - Cs12 = linalg.sqrtm(Cs) - Cs_12 = linalg.inv(Cs12) + Cs12 = nx.sqrtm(Cs) + Cs_12 = nx.inv(Cs12) - M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - A = Cs_12.dot(M0.dot(Cs_12)) + A = dots(Cs_12, M0, Cs_12) - b = mxt - mxs.dot(A) + b = mxt - nx.dot(mxs, A) if log: log = {} @@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix sim : string, optional Type of similarity ('knn' or 'gauss') used to construct the Laplacian. @@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al raise ValueError( 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__)) + a, b, xs, xt, M = list_to_array(a, b, xs, xt, M) + nx = get_backend(a, b, xs, xt, M) + if sim == 'gauss': if sim_param is None: - sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) sS = kernel(xs, xs, method=sim, sigma=sim_param) sT = kernel(xt, xt, method=sim, sigma=sim_param) @@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al from sklearn.neighbors import kneighbors_graph - sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray() + sS = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xs), n_neighbors=int(sim_param) + ).toarray(), type_as=xs) sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray() + sT = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xt), n_neighbors=int(sim_param) + ).toarray(), type_as=xt) sT = (sT + sT.T) / 2 else: raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) @@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al lT = laplacian(sT) def f(G): - return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ - + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + return ( + alpha * nx.trace(dots(xt.T, G.T, lS, G, xt)) + + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs)) + ) ls2 = lS + lS.T lt2 = lT + lT.T - xt2 = np.dot(xt, xt.T) + xt2 = nx.dot(xt, xt.T) if reg == 'disp': Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) @@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al M = M + Cs + Ct def df(G): - return alpha * np.dot(ls2, np.dot(G, xt2))\ - + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2))) + return ( + alpha * dots(ls2, G, xt2) + + (1 - alpha) * dots(xs, xs.T, G, lt2) + ) return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) @@ -919,7 +949,7 @@ def distribution_estimation_uniform(X): The uniform distribution estimated from :math:`\mathbf{X}` """ - return unif(X.shape[0]) + return unif(X.shape[0], type_as=X) class BaseTransport(BaseEstimator): @@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator): self : object Returns self. """ + nx = self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator): if (ys is not None) and (yt is not None): if self.limit_max != np.infty: - self.limit_max = self.limit_max * np.max(self.cost_) + self.limit_max = self.limit_max * nx.max(self.cost_) # assumes labeled source samples occupy the first rows # and labeled target samples occupy the first columns - classes = [c for c in np.unique(ys) if c != -1] + classes = [c for c in nx.unique(ys) if c != -1] for c in classes: - idx_s = np.where((ys != c) & (ys != -1)) - idx_t = np.where(yt == c) + idx_s = nx.where((ys != c) & (ys != -1)) + idx_t = nx.where(yt == c) # all the coefficients corresponding to a source sample # and a target sample : @@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator): for bi in batch_ind: # get the nearest neighbor in the source domain D0 = dist(Xs[bi], self.xs_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the source samples - transp = self.coupling_ / np.sum( - self.coupling_, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_ = np.dot(transp, self.xt_) + transp = self.coupling_ / nx.sum( + self.coupling_, axis=1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_ = nx.dot(transp, self.xt_) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator): International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - ysTemp = label_normalization(np.copy(ys)) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys)) + classes = nx.unique(ysTemp) n = len(classes) - D1 = np.zeros((n, len(ysTemp))) + D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True) + transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - transp_ys = np.dot(D1, transp) + transp_ys = nx.dot(D1, transp) return transp_ys.T @@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - if np.array_equal(self.xt_, Xt): + if nx.array_equal(self.xt_, Xt): # perform standard barycentric mapping - transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] + transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] # set nans to 0 - transp_[~ np.isfinite(transp_)] = 0 + transp_[~ nx.isfinite(transp_)] = 0 # compute transported samples - transp_Xt = np.dot(transp_, self.xs_) + transp_Xt = nx.dot(transp_, self.xs_) else: # perform out of sample mapping - indices = np.arange(Xt.shape[0]) + indices = nx.arange(Xt.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: D0 = dist(Xt[bi], self.xt_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the target samples - transp_ = self.coupling_.T / np.sum( + transp_ = self.coupling_.T / nx.sum( self.coupling_, 0)[:, None] - transp_[~ np.isfinite(transp_)] = 0 - transp_Xt_ = np.dot(transp_, self.xs_) + transp_[~ nx.isfinite(transp_)] = 0 + transp_Xt_ = nx.dot(transp_, self.xs_) # define the transported points transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :] transp_Xt.append(transp_Xt_) - transp_Xt = np.concatenate(transp_Xt, axis=0) + transp_Xt = nx.concatenate(transp_Xt, axis=0) return transp_Xt @@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator): transp_ys : array-like, shape (n_source_samples, nb_classes) Estimated soft source labels. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ytTemp == c] = 1 # compute propagated samples - transp_ys = np.dot(D1, transp.T) + transp_ys = nx.dot(D1, transp.T) return transp_ys.T @@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport): self : object Returns self. """ + nx = self._get_backend(Xs, ys, Xt, yt) self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) # coupling estimation returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=self.mu_s.reshape((-1, 1)), - wt=self.mu_t.reshape((-1, 1)), + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), bias=self.bias, log=self.log) # deal with the value of log @@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport): self.log_ = dict() # re compute inverse mapping - self.A1_ = linalg.inv(self.A_) - self.B1_ = -self.B_.dot(self.A1_) + self.A1_ = nx.inv(self.A_) + self.B1_ = -nx.dot(self.B_, self.A1_) return self @@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs = Xs.dot(self.A_) + self.B_ + transp_Xs = nx.dot(Xs, self.A_) + self.B_ return transp_Xs @@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt = Xt.dot(self.A1_) + self.B1_ + transp_Xt = nx.dot(Xt, self.A1_) + self.B1_ return transp_Xt @@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator): self : object Returns self """ + self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: if self.kernel == "gaussian": K = kernel(Xs, self.xs_, method=self.kernel, @@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator): elif self.kernel == "linear": K = Xs if self.bias: - K = np.hstack((K, np.ones((Xs.shape[0], 1)))) - transp_Xs = K.dot(self.mapping_) + K = nx.concatenate( + [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1 + ) + transp_Xs = nx.dot(K, self.mapping_) return transp_Xs @@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport): self : object Returns self. """ + self._get_backend(*Xs, *ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): @@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport): batch_size : int, optional (default=128) The batch size for out of sample inverse transform """ + nx = self.nx transp_Xs = [] # check the necessary inputs parameters are here if check_params(Xs=Xs): - if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]): # perform standard barycentric mapping for each source domain for coupling in self.coupling_: - transp = coupling / np.sum(coupling, 1)[:, None] + transp = coupling / nx.sum(coupling, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs.append(np.dot(transp, self.xt_)) + transp_Xs.append(nx.dot(transp, self.xt_)) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport): transp_Xs_ = [] # get the nearest neighbor in the sources domains - xs = np.concatenate(self.xs_, axis=0) - idx = np.argmin(dist(Xs[bi], xs), axis=1) + xs = nx.concatenate(self.xs_, axis=0) + idx = nx.argmin(dist(Xs[bi], xs), axis=1) # transport the source samples for coupling in self.coupling_: - transp = coupling / np.sum( - coupling, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_.append(np.dot(transp, self.xt_)) + transp = coupling / nx.sum(coupling, 1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_.append(nx.dot(transp, self.xt_)) - transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + transp_Xs_ = nx.concatenate(transp_Xs_, axis=0) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport): "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) + yt = nx.zeros( + (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]), + type_as=ys[0] + ) for i in range(len(ys)): - ysTemp = label_normalization(np.copy(ys[i])) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys[i])) + classes = nx.unique(ysTemp) n = len(classes) ns = len(ysTemp) # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 if self.log: D1 = self.log_['D1'][i] else: - D1 = np.zeros((n, ns)) + D1 = nx.zeros((n, ns), type_as=transp) for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - yt = yt + np.dot(D1, transp) / len(ys) + yt = yt + nx.dot(D1, transp) / len(ys) return yt.T @@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport): transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) A list of estimated soft source labels """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0]) for c in classes: D1[int(c), ytTemp == c] = 1 @@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport): for i in range(len(self.xs_)): # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute propagated labels - transp_ys.append(np.dot(D1, transp.T).T) + transp_ys.append(nx.dot(D1, transp.T).T) return transp_ys diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d9b6fa9..abf7fe0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -225,6 +225,13 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays. + .. note:: This function will cast the computed transport plan to the data type + of the provided input with the following priority: :math:`\mathbf{a}`, + then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -290,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M + if len(a0) != 0: + type_as = a0 + elif len(b0) != 0: + type_as = b0 + else: + type_as = M0 nx = get_backend(M0, a0, b0) # convert to numpy - M = nx.to_numpy(M) - a = nx.to_numpy(a) - b = nx.to_numpy(b) + M, a, b = nx.to_numpy(M, a, b) # ensure float64 a = np.asarray(a, dtype=np.float64) @@ -330,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) if log: log = {} log['cost'] = cost - log['u'] = nx.from_numpy(u, type_as=a0) - log['v'] = nx.from_numpy(v, type_as=b0) + log['u'] = nx.from_numpy(u, type_as=type_as) + log['v'] = nx.from_numpy(v, type_as=type_as) log['warning'] = result_code_string log['result_code'] = result_code - return nx.from_numpy(G, type_as=M0), log - return nx.from_numpy(G, type_as=M0) + return nx.from_numpy(G, type_as=type_as), log + return nx.from_numpy(G, type_as=type_as) def emd2(a, b, M, processes=1, @@ -364,6 +383,14 @@ def emd2(a, b, M, processes=1, from all compatible backends. But the algorithm uses the C++ CPU backend which can lead to copy overhead on GPU arrays. + .. note:: This function will cast the computed transport plan and + transportation loss to the data type of the provided input with the + following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`, + then :math:`\mathbf{M}` if marginals are not provided. + Casting to an integer tensor might result in a loss of precision. + If this behaviour is unwanted, please make sure to provide a + floating point input. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -432,12 +459,16 @@ def emd2(a, b, M, processes=1, a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M + if len(a0) != 0: + type_as = a0 + elif len(b0) != 0: + type_as = b0 + else: + type_as = M0 nx = get_backend(M0, a0, b0) # convert to numpy - M = nx.to_numpy(M) - a = nx.to_numpy(a) - b = nx.to_numpy(b) + M, a, b = nx.to_numpy(M, a, b) a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -470,14 +501,22 @@ def emd2(a, b, M, processes=1, result_code_string = check_result(result_code) log = {} - G = nx.from_numpy(G, type_as=M0) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) + G = nx.from_numpy(G, type_as=type_as) if return_matrix: log['G'] = G - log['u'] = nx.from_numpy(u, type_as=a0) - log['v'] = nx.from_numpy(v, type_as=b0) + log['u'] = nx.from_numpy(u, type_as=type_as) + log['v'] = nx.from_numpy(v, type_as=type_as) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: @@ -491,10 +530,18 @@ def emd2(a, b, M, processes=1, if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - G = nx.from_numpy(G, type_as=M0) - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0, b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0), G)) + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2 + ) + G = nx.from_numpy(G, type_as=type_as) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), + (a0, b0, M0), (nx.from_numpy(u, type_as=type_as), + nx.from_numpy(v, type_as=type_as), G)) check_result(result_code) return cost diff --git a/ot/optim.py b/ot/optim.py index f25e2c9..5a1d605 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -9,12 +9,19 @@ Generic solvers for regularized OT # License: MIT License import numpy as np -from scipy.optimize.linesearch import scalar_search_armijo +import warnings from .lp import emd from .bregman import sinkhorn -from ot.utils import list_to_array +from .utils import list_to_array from .backend import get_backend +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + from scipy.optimize import scalar_search_armijo + except ImportError: + from scipy.optimize.linesearch import scalar_search_armijo + # The corresponding scipy function does not work for matrices diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 15e180b..503cc1e 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -8,9 +8,9 @@ Regularized Unbalanced OT solvers from __future__ import division import warnings -import numpy as np -from scipy.special import logsumexp +from .backend import get_backend +from .utils import list_to_array # from .utils import unif, dist @@ -43,12 +43,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -70,12 +70,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -172,12 +172,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -198,7 +198,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Returns ------- - ot_distance : (n_hists,) ndarray + ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` log : dict log dictionary returned only if `log` is `True` @@ -239,9 +239,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epslilon scaling :ref:`[9, 10] ` """ - b = np.asarray(b, dtype=np.float64) + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] + if method.lower() == 'sinkhorn': return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=numItermax, @@ -291,12 +292,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b` If many, compute all the OT distances (a, b_i) - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -315,12 +316,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -354,17 +355,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -377,17 +376,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, 1)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, 1), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(M / (-reg)) fi = reg_m / (reg_m + reg) @@ -397,14 +393,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, uprev = u vprev = v - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (a / Kv) ** fi - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = (b / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -412,8 +408,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, v = vprev break - err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.) - err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.) + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) err = 0.5 * (err_u + err_v) if log: log['err'].append(err) @@ -426,11 +426,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, break if log: - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -475,12 +475,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Parameters ---------- - a : np.ndarray (dim_a,) + a : array-like (dim_a,) Unnormalized histogram of dimension `dim_a` - b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists) + b : array-like (dim_b,) or array-like (dim_b, n_hists) One or multiple unnormalized histograms of dimension `dim_b`. If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i` - M : np.ndarray (dim_a, dim_b) + M : array-like (dim_a, dim_b) loss matrix reg : float Entropy regularization term > 0 @@ -501,12 +501,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 Returns ------- if n_hists == 1: - - gamma : (dim_a, dim_b) ndarray + - gamma : (dim_a, dim_b) array-like Optimal transportation matrix for the given parameters - log : dict log dictionary returned only if `log` is `True` else: - - ot_distance : (n_hists,) ndarray + - ot_distance : (n_hists,) array-like the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` - log : dict log dictionary returned only if `log` is `True` @@ -538,17 +538,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 ot.optim.cg : General regularized OT """ - - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) dim_a, dim_b = M.shape if len(a) == 0: - a = np.ones(dim_a, dtype=np.float64) / dim_a + a = nx.ones(dim_a, type_as=M) / dim_a if len(b) == 0: - b = np.ones(dim_b, dtype=np.float64) / dim_b + b = nx.ones(dim_b, type_as=M) / dim_b if len(b.shape) > 1: n_hists = b.shape[1] @@ -561,56 +559,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b a = a.reshape(dim_a, 1) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim_a) - beta = np.zeros(dim_b) + alpha = nx.zeros(dim_a, type_as=M) + beta = nx.zeros(dim_b, type_as=M) while (err > stopThr and cpt < numItermax): uprev = u vprev = v - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) if n_hists: f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((a / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) v = ((b / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True if n_hists: - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) else: - alpha = alpha + reg * np.log(np.max(u)) - beta = beta + reg * np.log(np.max(v)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u)) + beta = beta + reg * nx.log(nx.max(v)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -620,8 +614,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 if (cpt % 10 == 0 and not absorbing) or cpt == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), - 1.) + err = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -636,25 +631,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 "Try a larger entropy `reg` or a lower mass `reg_m`." + "Or a larger absorption threshold `tau`.") if n_hists: - logu = alpha[:, None] / reg + np.log(u) - logv = beta[:, None] / reg + np.log(v) + logu = alpha[:, None] / reg + nx.log(u) + logv = beta[:, None] / reg + nx.log(v) else: - logu = alpha / reg + np.log(u) - logv = beta / reg + np.log(v) + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) if log: log['logu'] = logu log['logv'] = logv if n_hists: # return only loss - res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] + - logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1)) - res = np.exp(res) + res = nx.logsumexp( + nx.log(M + 1e-100)[:, :, None] + + logu[:, None, :] + + logv[None, :, :] + - M[:, :, None] / reg, + axis=(0, 1) + ) + res = nx.exp(res) if log: return res, log else: return res else: # return OT matrix - ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg) + ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg) if log: return ot_matrix, log else: @@ -683,9 +683,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 @@ -693,7 +693,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Marginal relaxation term > 0 tau : float Stabilization threshold for log domain absorption. - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -708,7 +708,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -726,9 +726,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) @@ -737,47 +740,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, fi = reg_m / (reg_m + reg) - u = np.ones((dim, n_hists)) / dim - v = np.ones((dim, n_hists)) / dim + u = nx.ones((dim, n_hists), type_as=A) / dim + v = nx.ones((dim, n_hists), type_as=A) / dim # print(reg) - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) cpt = 0 err = 1. - alpha = np.zeros(dim) - beta = np.zeros(dim) - q = np.ones(dim) / dim + alpha = nx.zeros(dim, type_as=A) + beta = nx.zeros(dim, type_as=A) + q = nx.ones(dim, type_as=A) / dim for i in range(numItermax): - qprev = q.copy() - Kv = K.dot(v) - f_alpha = np.exp(- alpha / (reg + reg_m)) - f_beta = np.exp(- beta / (reg + reg_m)) + qprev = nx.copy(q) + Kv = nx.dot(K, v) + f_alpha = nx.exp(- alpha / (reg + reg_m)) + f_beta = nx.exp(- beta / (reg + reg_m)) f_alpha = f_alpha[:, None] f_beta = f_beta[:, None] u = ((A / (Kv + 1e-16)) ** fi) * f_alpha - Ktu = K.T.dot(u) + Ktu = nx.dot(K.T, u) q = (Ktu ** (1 - fi)) * f_beta - q = q.dot(weights) ** (1 / (1 - fi)) + q = nx.dot(q, weights) ** (1 / (1 - fi)) Q = q[:, None] v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta absorbing = False - if (u > tau).any() or (v > tau).any(): + if nx.any(u > tau) or nx.any(v > tau): absorbing = True - alpha = alpha + reg * np.log(np.max(u, 1)) - beta = beta + reg * np.log(np.max(v, 1)) - K = np.exp((alpha[:, None] + beta[None, :] - - M) / reg) - v = np.ones_like(v) - Kv = K.dot(v) - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + alpha = alpha + reg * nx.log(nx.max(u, 1)) + beta = beta + reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(v.shape, type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % cpt) @@ -786,8 +785,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, if (i % 10 == 0 and not absorbing) or i == 0: # we can speed up the process by checking for the error only all # the 10th iterations - err = abs(q - qprev).max() / max(abs(q).max(), - abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1. + ) if log: log['err'].append(err) if verbose: @@ -804,8 +804,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, "Or a larger absorption threshold `tau`.") if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -833,15 +833,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -856,7 +856,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -874,40 +874,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, """ + A, M = list_to_array(A, M) + nx = get_backend(A, M) + dim, n_hists = A.shape if weights is None: - weights = np.ones(n_hists) / n_hists + weights = nx.ones(n_hists, type_as=A) / n_hists else: assert(len(weights) == A.shape[1]) if log: log = {'err': []} - K = np.exp(- M / reg) + K = nx.exp(-M / reg) fi = reg_m / (reg_m + reg) - v = np.ones((dim, n_hists)) - u = np.ones((dim, 1)) - q = np.ones(dim) + v = nx.ones((dim, n_hists), type_as=A) + u = nx.ones((dim, 1), type_as=A) + q = nx.ones(dim, type_as=A) err = 1. for i in range(numItermax): - uprev = u.copy() - vprev = v.copy() - qprev = q.copy() + uprev = nx.copy(u) + vprev = nx.copy(v) + qprev = nx.copy(q) - Kv = K.dot(v) + Kv = nx.dot(K, v) u = (A / Kv) ** fi - Ktu = K.T.dot(u) - q = ((Ktu ** (1 - fi)).dot(weights)) + Ktu = nx.dot(K.T, u) + q = nx.dot(Ktu ** (1 - fi), weights) q = q ** (1 / (1 - fi)) Q = q[:, None] v = (Q / Ktu) ** fi - if (np.any(Ktu == 0.) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop warnings.warn('Numerical errors at iteration %s' % i) @@ -916,8 +919,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, q = qprev break # compute change in barycenter - err = abs(q - qprev).max() - err /= max(abs(q).max(), abs(qprev).max(), 1.) + err = nx.max(nx.abs(q - qprev)) / max( + nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0 + ) if log: log['err'].append(err) # if barycenter did not change + at least 10 iterations - stop @@ -932,8 +936,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None, if log: log['niter'] = i - log['logu'] = np.log(u + 1e-300) - log['logv'] = np.log(v + 1e-300) + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) return q, log else: return q @@ -961,15 +965,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Parameters ---------- - A : np.ndarray (dim, n_hists) + A : array-like (dim, n_hists) `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim` - M : np.ndarray (dim, dim) + M : array-like (dim, dim) ground metric matrix for OT. reg : float Entropy regularization term > 0 reg_m: float Marginal relaxation term > 0 - weights : np.ndarray (n_hists,) optional + weights : array-like (n_hists,) optional Weight of each distribution (barycentric coodinates) If None, uniform weights are used. numItermax : int, optional @@ -984,7 +988,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, Returns ------- - a : (dim,) ndarray + a : (dim,) array-like Unbalanced Wasserstein barycenter log : dict log dictionary return only if log==True in parameters diff --git a/ot/utils.py b/ot/utils.py index 725ca00..a23ce7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature -from .backend import get_backend +from .backend import get_backend, Backend __time_tic_toc = time.time() @@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): def laplacian(x): r"""Compute Laplacian matrix""" - L = np.diag(np.sum(x, axis=0)) - x + nx = get_backend(x) + L = nx.diag(nx.sum(x, axis=0)) - x return L @@ -136,7 +137,7 @@ def unif(n, type_as=None): return np.ones((n,)) / n else: nx = get_backend(type_as) - return nx.ones((n,)) / n + return nx.ones((n,), type_as=type_as) / n def clean_zeros(a, b, M): @@ -296,7 +297,8 @@ def cost_normalization(C, norm=None): def dots(*args): r""" dots function for multiple matrix multiply """ - return reduce(np.dot, args) + nx = get_backend(*args) + return reduce(nx.dot, args) def label_normalization(y, start=0): @@ -314,8 +316,9 @@ def label_normalization(y, start=0): y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ + nx = get_backend(y) - diff = np.min(np.unique(y)) - start + diff = nx.min(nx.unique(y)) - start if diff != 0: y -= diff return y @@ -482,6 +485,19 @@ class BaseEstimator(object): arguments (no ``*args`` or ``**kwargs``). """ + nx: Backend = None + + def _get_backend(self, *arrays): + nx = get_backend( + *[input_ for input_ in arrays if input_ is not None] + ) + if nx.__name__ in ("jax", "tf"): + raise TypeError( + """JAX or TF arrays have been received but domain + adaptation does not support those backend.""") + self.nx = nx + return nx + @classmethod def _get_param_names(cls): r"""Get parameter names for the estimator""" diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 6a42cfe..20f307a 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -66,9 +66,7 @@ def test_wasserstein_1d(nx): rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) # test 1 : wasserstein_1d should be close to scipy W_1 implementation np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), @@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) @@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1) nx.assert_same_dtype_device(xb, res) assert nx.dtype_device(res)[1].startswith("GPU") @@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - rho_ub = nx.from_numpy(rho_u, type_as=tp) - rho_vb = nx.from_numpy(rho_v, type_as=tp) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) @@ -214,9 +204,7 @@ def test_emd1d_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) nx.assert_same_dtype_device(xb, emd) @@ -224,9 +212,7 @@ def test_emd1d_device_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - rho_ub = nx.from_numpy(rho_u) - rho_vb = nx.from_numpy(rho_v) + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) emd = ot.emd_1d(xb, xb, rho_ub, rho_vb) emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb) nx.assert_same_dtype_device(xb, emd) diff --git a/test/test_backend.py b/test/test_backend.py index 027c4cd..311c075 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -217,6 +217,8 @@ def test_empty_backend(): nx.zero_pad(M, v) with pytest.raises(NotImplementedError): nx.argmax(M) + with pytest.raises(NotImplementedError): + nx.argmin(M) with pytest.raises(NotImplementedError): nx.mean(M) with pytest.raises(NotImplementedError): @@ -264,12 +266,27 @@ def test_empty_backend(): nx.device_type(M) with pytest.raises(NotImplementedError): nx._bench(lambda x: x, M, n_runs=1) + with pytest.raises(NotImplementedError): + nx.solve(M, v) + with pytest.raises(NotImplementedError): + nx.trace(M) + with pytest.raises(NotImplementedError): + nx.inv(M) + with pytest.raises(NotImplementedError): + nx.sqrtm(M) + with pytest.raises(NotImplementedError): + nx.isfinite(M) + with pytest.raises(NotImplementedError): + nx.array_equal(M, M) + with pytest.raises(NotImplementedError): + nx.is_floating_point(M) def test_func_backends(nx): rnd = np.random.RandomState(0) M = rnd.randn(10, 3) + SquareM = rnd.randn(10, 10) v = rnd.randn(3) val = np.array([1.0]) @@ -288,6 +305,7 @@ def test_func_backends(nx): lst_name = [] Mb = nx.from_numpy(M) + SquareMb = nx.from_numpy(SquareM) vb = nx.from_numpy(v) val = nx.from_numpy(val) @@ -467,6 +485,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('argmax') + A = nx.argmin(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argmin') + A = nx.mean(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('mean') @@ -529,7 +551,11 @@ def test_func_backends(nx): A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0) lst_b.append(nx.to_numpy(A)) - lst_name.append('where') + lst_name.append('where (cond, x, y)') + + A = nx.where(nx.from_numpy(np.array([True, False]))) + lst_b.append(nx.to_numpy(nx.stack(A))) + lst_name.append('where (cond)') A = nx.copy(Mb) lst_b.append(nx.to_numpy(A)) @@ -550,15 +576,47 @@ def test_func_backends(nx): nx._bench(lambda x: x, M, n_runs=1) + A = nx.solve(SquareMb, Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('solve') + + A = nx.trace(SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('trace') + + A = nx.inv(SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('matrix inverse') + + A = nx.sqrtm(SquareMb.T @ SquareMb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("matrix square root") + + A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0) + A = nx.isfinite(A) + lst_b.append(nx.to_numpy(A)) + lst_name.append("isfinite") + + assert not nx.array_equal(Mb, vb), "array_equal (shape)" + assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" + assert not nx.array_equal( + Mb, Mb + nx.eye(*list(Mb.shape)) + ), "array_equal (elements) - expected false" + + assert nx.is_floating_point(Mb), "is_floating_point - expected true" + assert not nx.is_floating_point( + nx.from_numpy(np.array([0, 1, 2], dtype=int)) + ), "is_floating_point - expected false" + lst_tot.append(lst_b) lst_np = lst_tot[0] lst_b = lst_tot[1] for a1, a2, name in zip(lst_np, lst_b, lst_name): - if not np.allclose(a1, a2): - print('Assert fail on: ', name) - assert np.allclose(a1, a2, atol=1e-7) + np.testing.assert_allclose( + a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}' + ) def test_random_backends(nx): diff --git a/test/test_bregman.py b/test/test_bregman.py index 1419f9b..6c37984 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) - ab = nx.from_numpy(a) - M_nx = nx.from_numpy(M) + ab, M_nx = nx.from_numpy(a, M) Gb = ot.sinkhorn(ab, ab, M_nx, 1) @@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) - ab = nx.from_numpy(a) - M_nx = nx.from_numpy(M) + ab, M_nx = nx.from_numpy(a, M) Gb = ot.sinkhorn2(ab, ab, M_nx, 1) @@ -260,8 +258,7 @@ def test_sinkhorn_variants(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - M_nx = nx.from_numpy(M) + ub, M_nx = nx.from_numpy(u, M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ub = nx.from_numpy(u, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ub, Mb = nx.from_numpy(u, M, type_as=tp) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) @@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ub = nx.from_numpy(u, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ub, Mb = nx.from_numpy(u, M, type_as=tp) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) @@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method): # Check that everything stays on the CPU with tf.device("/CPU:0"): - ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + ub, Mb = nx.from_numpy(u, M) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) nx.assert_same_dtype_device(Mb, Gb) @@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + ub, Mb = nx.from_numpy(u, M) Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10) lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10) nx.assert_same_dtype_device(Mb, Gb) @@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M) + ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx): M = ot.dist(x, x) - ub = nx.from_numpy(u) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M) + ub, bb, M_nx = nx.from_numpy(u, b, M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) @@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_nx = nx.from_numpy(weights) + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) reg = 1e-2 if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": @@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_nx = nx.from_numpy(weights) + A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights) # wasserstein reg = 1e-2 @@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - A_nx = nx.from_numpy(A) - M_nx = nx.from_numpy(M) - weights_b = nx.from_numpy(weights) + A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights) # wasserstein reg = 1e-2 @@ -697,11 +680,7 @@ def test_unmix(nx): M0 /= M0.max() h0 = ot.unif(2) - ab = nx.from_numpy(a) - Db = nx.from_numpy(D) - M_nx = nx.from_numpy(M) - M0b = nx.from_numpy(M0) - h0b = nx.from_numpy(h0) + ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0) # wasserstein reg = 1e-3 @@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_mb = nx.from_numpy(M_m, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) @@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx): M = ot.dist(X_s, X_t) M_m = ot.dist(X_s, X_t, metric='euclidean') - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_mb = nx.from_numpy(M_m, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) @@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx): a = np.linspace(1, n, n) a /= a.sum() b = ot.unif(n) - X_s = np.reshape(np.arange(n), (n, 1)) - X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1)) + X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1)) + X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1)) M = ot.dist(X_s, X_t) M_s = ot.dist(X_s, X_s) M_t = ot.dist(X_t, X_t) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - X_sb = nx.from_numpy(X_s) - X_tb = nx.from_numpy(X_t) - M_nx = nx.from_numpy(M, type_as=ab) - M_sb = nx.from_numpy(M_s, type_as=ab) - M_tb = nx.from_numpy(M_t, type_as=ab) + ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( @@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx): M /= np.median(M) epsilon = 0.1 - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M, type_as=ab) + ab, bb, M_nx = nx.from_numpy(a, b, M) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, @@ -936,9 +897,7 @@ def test_screenkhorn(nx): x = rng.randn(n, 2) M = ot.dist(x, x) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - M_nx = nx.from_numpy(M, type_as=ab) + ab, bb, M_nx = nx.from_numpy(a, b, M) # sinkhorn G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1)) diff --git a/test/test_da.py b/test/test_da.py index 9f2bb50..4bf0ab1 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -19,7 +19,32 @@ except ImportError: nosklearn = True -def test_sinkhorn_lpl1_transport_class(): +def test_class_jax_tf(): + backends = [] + from ot.backend import jax, tf + if jax: + backends.append(ot.backend.JaxBackend()) + if tf: + backends.append(ot.backend.TensorflowBackend()) + + for nx in backends: + ns = 150 + nt = 200 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + + otda = ot.da.SinkhornLpl1Transport() + + with pytest.raises(TypeError): + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport """ @@ -29,6 +54,8 @@ def test_sinkhorn_lpl1_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.SinkhornLpl1Transport() # test its computed @@ -44,15 +71,15 @@ def test_sinkhorn_lpl1_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -62,7 +89,7 @@ def test_sinkhorn_lpl1_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -85,24 +112,26 @@ def test_sinkhorn_lpl1_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornLpl1Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) assert mass_semi == 0, "semisupervised mode not working" -def test_sinkhorn_l1l2_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_l1l2_transport_class(nx): """test_sinkhorn_transport """ @@ -112,6 +141,8 @@ def test_sinkhorn_l1l2_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.SinkhornL1l2Transport() # test its computed @@ -128,15 +159,15 @@ def test_sinkhorn_l1l2_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -156,7 +187,7 @@ def test_sinkhorn_l1l2_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -169,22 +200,22 @@ def test_sinkhorn_l1l2_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornL1l2Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornL1l2Transport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] - assert_allclose(mass_semi, np.zeros_like(mass_semi), + assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), rtol=1e-9, atol=1e-9) # check everything runs well with log=True @@ -193,7 +224,9 @@ def test_sinkhorn_l1l2_transport_class(): assert len(otda.log_.keys()) != 0 -def test_sinkhorn_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ @@ -203,6 +236,8 @@ def test_sinkhorn_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.SinkhornTransport() # test its computed @@ -219,15 +254,15 @@ def test_sinkhorn_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -247,7 +282,7 @@ def test_sinkhorn_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -260,19 +295,19 @@ def test_sinkhorn_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornTransport() otda_unsup.fit(Xs=Xs, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) assert mass_semi == 0, "semisupervised mode not working" @@ -282,7 +317,9 @@ def test_sinkhorn_transport_class(): assert len(otda.log_.keys()) != 0 -def test_unbalanced_sinkhorn_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_unbalanced_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ @@ -292,6 +329,8 @@ def test_unbalanced_sinkhorn_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.UnbalancedSinkhornTransport() # test its computed @@ -318,7 +357,7 @@ def test_unbalanced_sinkhorn_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -328,7 +367,7 @@ def test_unbalanced_sinkhorn_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -341,12 +380,12 @@ def test_unbalanced_sinkhorn_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornTransport() otda_unsup.fit(Xs=Xs, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.SinkhornTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" @@ -357,7 +396,9 @@ def test_unbalanced_sinkhorn_transport_class(): assert len(otda.log_.keys()) != 0 -def test_emd_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_emd_transport_class(nx): """test_sinkhorn_transport """ @@ -367,6 +408,8 @@ def test_emd_transport_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.EMDTransport() # test its computed @@ -382,15 +425,15 @@ def test_emd_transport_class(): mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -410,7 +453,7 @@ def test_emd_transport_class(): assert_equal(transp_ys.shape[0], ys.shape[0]) assert_equal(transp_ys.shape[1], len(np.unique(yt))) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -423,28 +466,32 @@ def test_emd_transport_class(): # test unsupervised vs semi-supervised mode otda_unsup = ot.da.EMDTransport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) - n_unsup = np.sum(otda_unsup.cost_) + n_unsup = nx.sum(otda_unsup.cost_) otda_semi = ot.da.EMDTransport() otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt) assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0]))) - n_semisup = np.sum(otda_semi.cost_) + n_semisup = nx.sum(otda_semi.cost_) # check that the cost matrix norms are indeed different assert n_unsup != n_semisup, "semisupervised mode not working" # check that the coupling forbids mass transport between labeled source # and labeled target samples - mass_semi = np.sum( + mass_semi = nx.sum( otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]) mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max] # we need to use a small tolerance here, otherwise the test breaks - assert_allclose(mass_semi, np.zeros_like(mass_semi), + assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)), rtol=1e-2, atol=1e-2) -def test_mapping_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +@pytest.mark.parametrize("kernel", ["linear", "gaussian"]) +@pytest.mark.parametrize("bias", ["unbiased", "biased"]) +def test_mapping_transport_class(nx, kernel, bias): """test_mapping_transport """ @@ -455,101 +502,29 @@ def test_mapping_transport_class(): Xt, yt = make_data_classif('3gauss2', nt) Xs_new, _ = make_data_classif('3gauss', ns + 1) - ########################################################################## - # kernel == linear mapping tests - ########################################################################## + Xs, Xt, Xs_new = nx.from_numpy(Xs, Xt, Xs_new) - # check computation and dimensions if bias == False - otda = ot.da.MappingTransport(kernel="linear", bias=False) + # Mapping tests + bias = bias == "biased" + otda = ot.da.MappingTransport(kernel=kernel, bias=bias) otda.fit(Xs=Xs, Xt=Xt) assert hasattr(otda, "coupling_") assert hasattr(otda, "mapping_") assert hasattr(otda, "log_") assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1]))) + S = Xs.shape[0] if kernel == "gaussian" else Xs.shape[1] # if linear + if bias: + S += 1 + assert_equal(otda.mapping_.shape, ((S, Xt.shape[1]))) # test margin constraints mu_s = unif(ns) mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) - - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - # check computation and dimensions if bias == True - otda = ot.da.MappingTransport(kernel="linear", bias=True) - otda.fit(Xs=Xs, Xt=Xt) - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1]))) - - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) - - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - ########################################################################## - # kernel == gaussian mapping tests - ########################################################################## - - # check computation and dimensions if bias == False - otda = ot.da.MappingTransport(kernel="gaussian", bias=False) - otda.fit(Xs=Xs, Xt=Xt) - - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1]))) - - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - assert_equal(transp_Xs.shape, Xs.shape) - - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - # check computation and dimensions if bias == True - otda = ot.da.MappingTransport(kernel="gaussian", bias=True) - otda.fit(Xs=Xs, Xt=Xt) - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1]))) - - # test margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -561,29 +536,39 @@ def test_mapping_transport_class(): assert_equal(transp_Xs_new.shape, Xs_new.shape) # check everything runs well with log=True - otda = ot.da.MappingTransport(kernel="gaussian", log=True) + otda = ot.da.MappingTransport(kernel=kernel, bias=bias, log=True) otda.fit(Xs=Xs, Xt=Xt) assert len(otda.log_.keys()) != 0 + +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_mapping_transport_class_specific_seed(nx): # check that it does not crash when derphi is very close to 0 + ns = 20 + nt = 30 np.random.seed(39) Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) otda = ot.da.MappingTransport(kernel="gaussian", bias=False) - otda.fit(Xs=Xs, Xt=Xt) + otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt)) np.random.seed(None) -def test_linear_mapping(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_linear_mapping(nx): ns = 150 nt = 200 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) - A, b = ot.da.OT_mapping_linear(Xs, Xt) + Xsb, Xtb = nx.from_numpy(Xs, Xt) - Xst = Xs.dot(A) + b + A, b = ot.da.OT_mapping_linear(Xsb, Xtb) + + Xst = nx.to_numpy(nx.dot(Xsb, A) + b) Ct = np.cov(Xt.T) Cst = np.cov(Xst.T) @@ -591,22 +576,26 @@ def test_linear_mapping(): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) -def test_linear_mapping_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_linear_mapping_class(nx): ns = 150 nt = 200 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xsb, Xtb = nx.from_numpy(Xs, Xt) + otmap = ot.da.LinearTransport() - otmap.fit(Xs=Xs, Xt=Xt) + otmap.fit(Xs=Xsb, Xt=Xtb) assert hasattr(otmap, "A_") assert hasattr(otmap, "B_") assert hasattr(otmap, "A1_") assert hasattr(otmap, "B1_") - Xst = otmap.transform(Xs=Xs) + Xst = nx.to_numpy(otmap.transform(Xs=Xsb)) Ct = np.cov(Xt.T) Cst = np.cov(Xst.T) @@ -614,7 +603,9 @@ def test_linear_mapping_class(): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) -def test_jcpot_transport_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_jcpot_transport_class(nx): """test_jcpot_transport """ @@ -627,6 +618,8 @@ def test_jcpot_transport_class(): Xt, yt = make_data_classif('3gauss2', nt) + Xs1, ys1, Xs2, ys2, Xt, yt = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt, yt) + Xs = [Xs1, Xs2] ys = [ys1, ys2] @@ -649,19 +642,24 @@ def test_jcpot_transport_class(): for i in range(len(Xs)): # test margin constraints w.r.t. uniform target weights for each coupling matrix assert_allclose( - np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3) # test margin constraints w.r.t. modified source weights for each source domain assert_allclose( - np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, - atol=1e-3) + nx.to_numpy( + nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1)) + ), + nx.to_numpy(otda.proportions_), + rtol=1e-3, + atol=1e-3 + ) # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new, _ = make_data_classif('3gauss', ns1 + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -670,15 +668,16 @@ def test_jcpot_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(ys))) + assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys)))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) - [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)] - [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)] + for x, y in zip(transp_ys, ys): + assert_equal(x.shape[0], y.shape[0]) + assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y)))) -def test_jcpot_barycenter(): +def test_jcpot_barycenter(nx): """test_jcpot_barycenter """ @@ -695,19 +694,23 @@ def test_jcpot_barycenter(): Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) - Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + Xt, _ = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) - Xs = [Xs1, Xs2] - ys = [ys1, ys2] + Xs1b, ys1b, Xs2b, ys2b, Xtb = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt) - prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean', + Xsb = [Xs1b, Xs2b] + ysb = [ys1b, ys2b] + + prop = ot.bregman.jcpot_barycenter(Xsb, ysb, Xtb, reg=.5, metric='sqeuclidean', numItermax=10000, stopThr=1e-9, verbose=False, log=False) - np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(nx.to_numpy(prop), [1 - pt, pt], rtol=1e-3, atol=1e-3) @pytest.mark.skipif(nosklearn, reason="No sklearn available") -def test_emd_laplace_class(): +@pytest.skip_backend("jax") +@pytest.skip_backend("tf") +def test_emd_laplace_class(nx): """test_emd_laplace_transport """ ns = 150 @@ -716,6 +719,8 @@ def test_emd_laplace_class(): Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True) # test its computed @@ -732,15 +737,15 @@ def test_emd_laplace_class(): mu_t = unif(nt) assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3) assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - Xs_new, _ = make_data_classif('3gauss', ns + 1) + Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0]) transp_Xs_new = otda.transform(Xs_new) # check that the oos method is working @@ -750,7 +755,7 @@ def test_emd_laplace_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) - Xt_new, _ = make_data_classif('3gauss2', nt + 1) + Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0]) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) # check that the oos method is working @@ -763,9 +768,9 @@ def test_emd_laplace_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) - assert_equal(transp_yt.shape[1], len(np.unique(ys))) + assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys)))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) - assert_equal(transp_ys.shape[1], len(np.unique(yt))) + assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt)))) diff --git a/test/test_gromov.py b/test/test_gromov.py index 0dcf2da..12fd2b9 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -35,11 +35,7 @@ def test_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True) Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)) @@ -105,11 +101,7 @@ def test_gromov_dtype_device(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - C1b = nx.from_numpy(C1, type_as=tp) - C2b = nx.from_numpy(C2, type_as=tp) - pb = nx.from_numpy(p, type_as=tp) - qb = nx.from_numpy(q, type_as=tp) - G0b = nx.from_numpy(G0, type_as=tp) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) @@ -136,11 +128,7 @@ def test_gromov_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -148,11 +136,7 @@ def test_gromov_device_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0b) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True) gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False) nx.assert_same_dtype_device(C1b, Gb) @@ -222,10 +206,7 @@ def test_entropic_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) @@ -285,10 +266,7 @@ def test_entropic_gromov_dtype_device(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - C1b = nx.from_numpy(C1, type_as=tp) - C2b = nx.from_numpy(C2, type_as=tp) - pb = nx.from_numpy(p, type_as=tp) - qb = nx.from_numpy(q, type_as=tp) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp) Gb = ot.gromov.entropic_gromov_wasserstein( C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True @@ -320,10 +298,7 @@ def test_pointwise_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) def loss(x, y): return np.abs(x - y) @@ -381,10 +356,7 @@ def test_sampled_gromov(nx): C1 /= C1.max() C2 /= C2.max() - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) + C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q) def loss(x, y): return np.abs(x - y) @@ -423,19 +395,15 @@ def test_gromov_barycenter(nx): n_samples = 3 p = ot.unif(n_samples) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - p1b = nx.from_numpy(p1) - p2b = nx.from_numpy(p2) - pb = nx.from_numpy(p) + C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 ) Cbb = nx.to_numpy(ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42 + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42 )) np.testing.assert_allclose(Cb, Cbb, atol=1e-06) np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) @@ -443,15 +411,15 @@ def test_gromov_barycenter(nx): # test of gromov_barycenters with `log` on Cb_, err_ = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cbb_, errb_ = ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], errb_['err']) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) Cb2 = ot.gromov.gromov_barycenters( @@ -468,15 +436,15 @@ def test_gromov_barycenter(nx): # test of gromov_barycenters with `log` on Cb2_, err2_ = ot.gromov.gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cb2b_, err2b_ = ot.gromov.gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5], - 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True + 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], err2_['err']) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) @@ -495,11 +463,7 @@ def test_gromov_entropic_barycenter(nx): n_samples = 2 p = ot.unif(n_samples) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - p1b = nx.from_numpy(p1) - p2b = nx.from_numpy(p2) - pb = nx.from_numpy(p) + C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p) Cb = ot.gromov.entropic_gromov_barycenters( n_samples, [C1, C2], [p1, p2], p, [.5, .5], @@ -523,7 +487,7 @@ def test_gromov_entropic_barycenter(nx): ) Cbb_ = nx.to_numpy(Cbb_) np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], errb_['err']) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) Cb2 = ot.gromov.entropic_gromov_barycenters( @@ -548,7 +512,7 @@ def test_gromov_entropic_barycenter(nx): ) Cb2b_ = nx.to_numpy(Cb2b_) np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) - np.testing.assert_array_almost_equal(err2_['err'], err2_['err']) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) @@ -578,12 +542,7 @@ def test_fgw(nx): M = ot.dist(ys, yt) M /= M.max() - Mb = nx.from_numpy(M) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - pb = nx.from_numpy(p) - qb = nx.from_numpy(q) - G0b = nx.from_numpy(G0) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True) Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True) @@ -681,13 +640,7 @@ def test_fgw_barycenter(nx): n_samples = 3 p = ot.unif(n_samples) - ysb = nx.from_numpy(ys) - ytb = nx.from_numpy(yt) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - p1b = nx.from_numpy(p1) - p2b = nx.from_numpy(p2) - pb = nx.from_numpy(p) + ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) Xb, Cb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, @@ -731,10 +684,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): Cdict = np.stack([C1, C2]) p = ot.unif(n) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - Cdictb = nx.from_numpy(Cdict) - pb = nx.from_numpy(p) + C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p) + tol = 10**(-5) # Tests without regularization reg = 0. @@ -764,8 +715,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -798,8 +749,8 @@ def test_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -824,13 +775,14 @@ def test_gromov_wasserstein_dictionary_learning(nx): dataset_means = [C.mean() for C in Cs] np.random.seed(0) Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape)) + if projection == 'nonnegative_symmetric': Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1))) Cdict_init[Cdict_init < 0.] = 0. - Csb = [nx.from_numpy(C) for C in Cs] - psb = [nx.from_numpy(p) for p in ps] - qb = nx.from_numpy(q) - Cdict_initb = nx.from_numpy(Cdict_init) + + Csb = nx.from_numpy(*Cs) + psb = nx.from_numpy(*ps) + qb, Cdict_initb = nx.from_numpy(q, Cdict_init) # Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization # > Compute initial reconstruction of samples on this random dictionary without backend @@ -882,6 +834,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b += reconstruction + total_reconstruction_b = nx.to_numpy(total_reconstruction_b) np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) @@ -924,6 +877,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b_bis += reconstruction + total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03) @@ -969,6 +923,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b_bis2 += reconstruction + total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05) @@ -985,12 +940,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): Ydict = np.stack([F, F]) p = ot.unif(n) - C1b = nx.from_numpy(C1) - C2b = nx.from_numpy(C2) - Fb = nx.from_numpy(F) - Cdictb = nx.from_numpy(Cdict) - Ydictb = nx.from_numpy(Ydict) - pb = nx.from_numpy(p) + C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p) + # Tests without regularization reg = 0. @@ -1022,8 +973,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -1058,8 +1009,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03) - np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06) - np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06) + np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) + np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06) np.testing.assert_allclose(C1b_emb.shape, (n, n)) np.testing.assert_allclose(C2b_emb.shape, (n, n)) @@ -1093,12 +1044,10 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys]) Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2)) - Csb = [nx.from_numpy(C) for C in Cs] - Ysb = [nx.from_numpy(Y) for Y in Ys] - psb = [nx.from_numpy(p) for p in ps] - qb = nx.from_numpy(q) - Cdict_initb = nx.from_numpy(Cdict_init) - Ydict_initb = nx.from_numpy(Ydict_init) + Csb = nx.from_numpy(*Cs) + Ysb = nx.from_numpy(*Ys) + psb = nx.from_numpy(*ps) + qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init) # Test: Compute initial reconstruction of samples on this random dictionary alpha = 0.5 @@ -1151,6 +1100,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): ) total_reconstruction_b += reconstruction + total_reconstruction_b = nx.to_numpy(total_reconstruction_b) np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction) np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05) np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03) @@ -1192,6 +1142,8 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 ) total_reconstruction_b_bis += reconstruction + + total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis) np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05) # Test: without using adam optimizer, with log and verbose set to True @@ -1237,4 +1189,5 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): total_reconstruction_b_bis2 += reconstruction # > Compare results with/without backend + total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2) np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05) diff --git a/test/test_optim.py b/test/test_optim.py index 41f9cbe..67e9d13 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -32,9 +32,7 @@ def test_conditional_gradient(nx): def fb(G): return 0.5 * nx.sum(G ** 2) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + ab, bb, Mb = nx.from_numpy(a, b, M) reg = 1e-1 @@ -74,9 +72,7 @@ def test_conditional_gradient_itermax(nx): def fb(G): return 0.5 * nx.sum(G ** 2) - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + ab, bb, Mb = nx.from_numpy(a, b, M) reg = 1e-1 @@ -118,9 +114,7 @@ def test_generalized_conditional_gradient(nx): reg1 = 1e-3 reg2 = 1e-1 - ab = nx.from_numpy(a) - bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + ab, bb, Mb = nx.from_numpy(a, b, M) G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True) Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True) @@ -142,9 +136,12 @@ def test_line_search_armijo(nx): pk = np.array([[-0.25, 0.25], [0.25, -0.25]]) gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]]) old_fval = -123 + + xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk) + # Should not throw an exception and return 0. for alpha alpha, a, b = ot.optim.line_search_armijo( - lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval + lambda x: 1, xkb, pkb, gfkb, old_fval ) alpha_np, anp, bnp = ot.optim.line_search_armijo( lambda x: 1, xk, pk, gfk, old_fval diff --git a/test/test_ot.py b/test/test_ot.py index 3e2d845..bb258e2 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -47,8 +47,7 @@ def test_emd_backends(nx): G = ot.emd(a, a, M) - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) @@ -68,8 +67,7 @@ def test_emd2_backends(nx): val = ot.emd2(a, a, M) - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) valb = ot.emd2(ab, ab, Mb) @@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ab = nx.from_numpy(a, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ab, Mb = nx.from_numpy(a, M, type_as=tp) Gb = ot.emd(ab, ab, Mb) @@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) w = ot.emd2(ab, ab, Mb) nx.assert_same_dtype_device(Mb, Gb) @@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) w = ot.emd2(ab, ab, Mb) nx.assert_same_dtype_device(Mb, Gb) @@ -310,8 +305,8 @@ def test_free_support_barycenter_backends(nx): X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) - measures_locations2 = [nx.from_numpy(x) for x in measures_locations] - measures_weights2 = [nx.from_numpy(x) for x in measures_weights] + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) X_init2 = nx.from_numpy(X_init) X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2) diff --git a/test/test_sliced.py b/test/test_sliced.py index 91e0961..08ab4fb 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -123,9 +123,7 @@ def test_sliced_backend(nx): n_projections = 20 - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) val0 = ot.sliced_wasserstein_distance(x, y, projections=P) @@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - yb = nx.from_numpy(y, type_as=tp) - Pb = nx.from_numpy(P, type_as=tp) + xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -174,17 +170,13 @@ def test_sliced_backend_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") @@ -203,9 +195,7 @@ def test_max_sliced_backend(nx): n_projections = 20 - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) @@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - xb = nx.from_numpy(x, type_as=tp) - yb = nx.from_numpy(y, type_as=tp) - Pb = nx.from_numpy(P, type_as=tp) + xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) @@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - xb = nx.from_numpy(x) - yb = nx.from_numpy(y) - Pb = nx.from_numpy(P) + xb, yb, Pb = nx.from_numpy(x, y, P) valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index e8349d1..db59504 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -9,11 +9,9 @@ import ot import pytest from ot.unbalanced import barycenter_unbalanced -from scipy.special import logsumexp - @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_convergence(method): +def test_unbalanced_convergence(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -28,36 +26,51 @@ def test_unbalanced_convergence(method): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True) - loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method=method, - verbose=True) + loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method=method, verbose=True + )) # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16) - logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) - logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16) + logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1) + logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) # check if sinkhorn_unbalanced2 returns the correct loss - np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5) + np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5) + + # check in case no histogram is provided + M_np = nx.to_numpy(M) + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G = ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + G_np = ot.unbalanced.sinkhorn_unbalanced( + a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True + ) + np.testing.assert_allclose(G_np, nx.to_numpy(G)) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_multiple_inputs(method): +def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100 rng = np.random.RandomState(42) @@ -72,6 +85,8 @@ def test_unbalanced_multiple_inputs(method): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, reg_m=reg_m, method=method, @@ -80,23 +95,24 @@ def test_unbalanced_multiple_inputs(method): # check fixed point equations # in log-domain fi = reg_m / (reg_m + epsilon) - logb = np.log(b + 1e-16) - loga = np.log(a + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logb = nx.log(b + 1e-16) + loga = nx.log(a + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logb - logKtu) u_final = fi * (loga - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) assert len(loss) == b.shape[1] -def test_stabilized_vs_sinkhorn(): +def test_stabilized_vs_sinkhorn(nx): # test if stable version matches sinkhorn n = 100 @@ -112,19 +128,27 @@ def test_stabilized_vs_sinkhorn(): M /= np.median(M) epsilon = 0.1 reg_m = 1. - G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon, - method="sinkhorn_stabilized", - reg_m=reg_m, - log=True, - verbose=True) - G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method="sinkhorn", log=True) + + ab, bb, Mb = nx.from_numpy(a, b, M) + + G, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True + ) + G2, _ = ot.unbalanced.sinkhorn_unbalanced2( + ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True + ) + G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method="sinkhorn", log=True + ) + G = nx.to_numpy(G) + G2 = nx.to_numpy(G2) np.testing.assert_allclose(G, G2, atol=1e-5) + np.testing.assert_allclose(G2, G2_np, atol=1e-5) @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_unbalanced_barycenter(method): +def test_unbalanced_barycenter(nx, method): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -138,25 +162,29 @@ def test_unbalanced_barycenter(method): epsilon = 1. reg_m = 1. - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method=method, log=True, verbose=True) + A, M = nx.from_numpy(A, M) + + q, log = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True + ) # check fixed point equations fi = reg_m / (reg_m + epsilon) - logA = np.log(A + 1e-16) - logq = np.log(q + 1e-16)[:, None] - logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon, - axis=0) - logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) + logA = nx.log(A + 1e-16) + logq = nx.log(q + 1e-16)[:, None] + logKtu = nx.logsumexp( + log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0 + ) + logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1) v_final = fi * (logq - logKtu) u_final = fi * (logA - logKv) np.testing.assert_allclose( - u_final, log["logu"], atol=1e-05) + nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05) np.testing.assert_allclose( - v_final, log["logv"], atol=1e-05) + nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05) -def test_barycenter_stabilized_vs_sinkhorn(): +def test_barycenter_stabilized_vs_sinkhorn(nx): # test generalized sinkhorn for unbalanced OT barycenter n = 100 rng = np.random.RandomState(42) @@ -170,21 +198,24 @@ def test_barycenter_stabilized_vs_sinkhorn(): epsilon = 0.5 reg_m = 10 - qstable, log = barycenter_unbalanced(A, M, reg=epsilon, - reg_m=reg_m, log=True, - tau=100, - method="sinkhorn_stabilized", - verbose=True - ) - q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, - method="sinkhorn", - log=True) + Ab, Mb = nx.from_numpy(A, M) - np.testing.assert_allclose( - q, qstable, atol=1e-05) + qstable, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100, + method="sinkhorn_stabilized", verbose=True + ) + q, _ = barycenter_unbalanced( + Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q_np, _ = barycenter_unbalanced( + A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True + ) + q, qstable = nx.to_numpy(q, qstable) + np.testing.assert_allclose(q, qstable, atol=1e-05) + np.testing.assert_allclose(q, q_np, atol=1e-05) -def test_wrong_method(): +def test_wrong_method(nx): n = 10 rng = np.random.RandomState(42) @@ -199,19 +230,20 @@ def test_wrong_method(): epsilon = 1. reg_m = 1. + a, b, M = nx.from_numpy(a, b, M) + with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, - reg_m=reg_m, - method='badmethod', - log=True, - verbose=True) + ot.unbalanced.sinkhorn_unbalanced( + a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod', + log=True, verbose=True + ) with pytest.raises(ValueError): - ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m, - method='badmethod', - verbose=True) + ot.unbalanced.sinkhorn_unbalanced2( + a, b, M, epsilon, reg_m, method='badmethod', verbose=True + ) -def test_implemented_methods(): +def test_implemented_methods(nx): IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized'] TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling'] NOT_VALID_TOKENS = ['foo'] @@ -228,6 +260,9 @@ def test_implemented_methods(): M = ot.dist(x, x) epsilon = 1. reg_m = 1. + + a, b, M, A = nx.from_numpy(a, b, M, A) + for method in IMPLEMENTED_METHODS: ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m, method=method) diff --git a/test/test_weak.py b/test/test_weak.py index c4c3278..945efb1 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -45,9 +45,7 @@ def test_weak_ot_bakends(nx): G = ot.weak_optimal_transport(xs, xt, u, u) - xs2 = nx.from_numpy(xs) - xt2 = nx.from_numpy(xt) - u2 = nx.from_numpy(u) + xs2, xt2, u2 = nx.from_numpy(xs, xt, u) G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) -- cgit v1.2.3 From 82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 24 Mar 2022 14:13:25 +0100 Subject: [MRG] Add factored coupling (#358) * add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests --- README.md | 4 +- RELEASES.md | 1 + docs/source/all.rst | 1 + examples/others/plot_factored_coupling.py | 86 ++++++++++++++++++ ot/__init__.py | 5 ++ ot/factored.py | 145 ++++++++++++++++++++++++++++++ ot/plot.py | 7 +- test/test_factored.py | 56 ++++++++++++ 8 files changed, 303 insertions(+), 2 deletions(-) create mode 100644 examples/others/plot_factored_coupling.py create mode 100644 ot/factored.py create mode 100644 test/test_factored.py diff --git a/README.md b/README.md index c6bfd9c..ec5d221 100644 --- a/README.md +++ b/README.md @@ -305,4 +305,6 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. -[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. \ No newline at end of file +[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. + +[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index 86b401a..c2bd0d1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- Implementation of factored OT with emd and sinkhorn (PR #358). - A brand new logo for POT (PR #357) - Better list of related examples in quick start guide with `minigallery` (PR #334). - Add optional log-domain Sinkhorn implementation in WDA to support smaller values diff --git a/docs/source/all.rst b/docs/source/all.rst index 76d2ff5..3f7d029 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -29,6 +29,7 @@ API and modules partial sliced weak + factored .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/others/plot_factored_coupling.py b/examples/others/plot_factored_coupling.py new file mode 100644 index 0000000..b5b1c9f --- /dev/null +++ b/examples/others/plot_factored_coupling.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +========================================== +Optimal transport with factored couplings +========================================== + +Illustration of the factored coupling OT between 2D empirical distributions + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot + +# %% +# Generate data an plot it +# ------------------------ + +# parameters and data generation + +np.random.seed(42) + +n = 100 # nb samples + +xs = np.random.rand(n, 2) - .5 + +xs = xs + np.sign(xs) + +xt = np.random.rand(n, 2) - .5 + +a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples + +#%% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Compute Factore OT and exact OT solutions +# -------------------------------------- + +#%% EMD +M = ot.dist(xs, xt) +G0 = ot.emd(a, b, M) + +#%% factored OT OT + +Ga, Gb, xb = ot.factored_optimal_transport(xs, xt, a, b, r=4) + + +# %% +# Plot factored OT and exact OT solutions +# -------------------------------------- + +pl.figure(2, (14, 4)) + +pl.subplot(1, 3, 1) +ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Exact OT with samples') + +pl.subplot(1, 3, 2) +ot.plot.plot2D_samples_mat(xs, xb, Ga, c=[.6, .6, .9], alpha=0.5) +ot.plot.plot2D_samples_mat(xb, xt, Gb, c=[.9, .6, .6], alpha=0.5) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.plot(xb[:, 0], xb[:, 1], 'og', label='Template samples') +pl.title('Factored OT with template samples') + +pl.subplot(1, 3, 3) +ot.plot.plot2D_samples_mat(xs, xt, Ga.dot(Gb), c=[.2, .2, .2], alpha=0.1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.title('Factored OT low rank OT plan') diff --git a/ot/__init__.py b/ot/__init__.py index bda7a35..c5e1967 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -33,6 +33,7 @@ from . import partial from . import backend from . import regpath from . import weak +from . import factored # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -44,6 +45,9 @@ from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport +from .factored import factored_optimal_transport + + # utils functions from .utils import dist, unif, tic, toc, toq @@ -57,4 +61,5 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', + 'factored_optimal_transport', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/factored.py b/ot/factored.py new file mode 100644 index 0000000..abc2445 --- /dev/null +++ b/ot/factored.py @@ -0,0 +1,145 @@ +""" +Factored OT solvers (low rank, cost or OT plan) +""" + +# Author: Remi Flamary +# +# License: MIT License + +from .backend import get_backend +from .utils import dist +from .lp import emd +from .bregman import sinkhorn + +__all__ = ['factored_optimal_transport'] + + +def factored_optimal_transport(Xa, Xb, a=None, b=None, reg=0.0, r=100, X0=None, stopThr=1e-7, numItermax=100, verbose=False, log=False, **kwargs): + r"""Solves factored OT problem and return OT plans and intermediate distribution + + This function solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where : + + - :math:`\mu_a` and :math:`\mu_b` are empirical distributions. + - :math:`\mu` is an empirical distribution with r samples + + And returns the two OT plans between + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. But the algorithm uses the C++ CPU backend + which can lead to copy overhead on GPU arrays. + + Uses the conditional gradient algorithm to solve the problem proposed in + :ref:`[39] `. + + Parameters + ---------- + Xa : (ns,d) array-like, float + Source samples + Xb : (nt,d) array-like, float + Target samples + a : (ns,) array-like, float + Source histogram (uniform weight if empty list) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list)) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on the relative variation (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + Ga: array-like, shape (ns, r) + Optimal transportation matrix between source and the intermediate + distribution + Gb: array-like, shape (r, nt) + Optimal transportation matrix between the intermediate and target + distribution + X: array-like, shape (r, d) + Support of the intermediate distribution + log: dict, optional + If input log is true, a dictionary containing the cost and dual + variables and exit status + + + .. _references-factored: + References + ---------- + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT + """ + + nx = get_backend(Xa, Xb) + + n_a = Xa.shape[0] + n_b = Xb.shape[0] + d = Xa.shape[1] + + if a is None: + a = nx.ones((n_a), type_as=Xa) / n_a + if b is None: + b = nx.ones((n_b), type_as=Xb) / n_b + + if X0 is None: + X = nx.randn(r, d, type_as=Xa) + else: + X = X0 + + w = nx.ones(r, type_as=Xa) / r + + def solve_ot(X1, X2, w1, w2): + M = dist(X1, X2) + if reg > 0: + G, log = sinkhorn(w1, w2, M, reg, log=True, **kwargs) + log['cost'] = nx.sum(G * M) + return G, log + else: + return emd(w1, w2, M, log=True, **kwargs) + + norm_delta = [] + + # solve the barycenter + for i in range(numItermax): + + old_X = X + + # solve OT with template + Ga, loga = solve_ot(Xa, X, a, w) + Gb, logb = solve_ot(X, Xb, w, b) + + X = 0.5 * (nx.dot(Ga.T, Xa) + nx.dot(Gb, Xb)) * r + + delta = nx.norm(X - old_X) + if delta < stopThr: + break + if log: + norm_delta.append(delta) + + if log: + log_dic = {'delta_iter': norm_delta, + 'ua': loga['u'], + 'va': loga['v'], + 'ub': logb['u'], + 'vb': logb['v'], + 'costa': loga['cost'], + 'costb': logb['cost'], + } + return Ga, Gb, X, log_dic + + return Ga, Gb, X diff --git a/ot/plot.py b/ot/plot.py index 2208c90..8ade2eb 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): if ('color' not in kwargs) and ('c' not in kwargs): kwargs['color'] = 'k' mx = G.max() + if 'alpha' in kwargs: + scale = kwargs['alpha'] + del kwargs['alpha'] + else: + scale = 1 for i in range(xs.shape[0]): for j in range(xt.shape[0]): if G[i, j] / mx > thr: pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], - alpha=G[i, j] / mx, **kwargs) + alpha=G[i, j] / mx * scale, **kwargs) diff --git a/test/test_factored.py b/test/test_factored.py new file mode 100644 index 0000000..fd2fd01 --- /dev/null +++ b/test/test_factored.py @@ -0,0 +1,56 @@ +"""Tests for main module ot.weak """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import numpy as np + + +def test_factored_ot(): + # test weak ot solver and identity stationary point + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, r=10, log=True) + + # check constraints + np.testing.assert_allclose(u, Ga.sum(1)) + np.testing.assert_allclose(u, Gb.sum(0)) + + Ga, Gb, X, log = ot.factored_optimal_transport(xs, xt, u, u, reg=1, r=10, log=True) + + # check constraints + np.testing.assert_allclose(u, Ga.sum(1)) + np.testing.assert_allclose(u, Gb.sum(0)) + + +def test_factored_ot_backends(nx): + # test weak ot solver for different backends + n = 50 + rng = np.random.RandomState(0) + + xs = rng.randn(n, 2) + xt = rng.randn(n, 2) + u = ot.utils.unif(n) + + xs2 = nx.from_numpy(xs) + xt2 = nx.from_numpy(xt) + u2 = nx.from_numpy(u) + + Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, u2, u2, r=10) + + # check constraints + np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1)) + np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0)) + + Ga2, Gb2, X2 = ot.factored_optimal_transport(xs2, xt2, reg=1, r=10, X0=X2) + + # check constraints + np.testing.assert_allclose(u, nx.to_numpy(Ga2).sum(1)) + np.testing.assert_allclose(u, nx.to_numpy(Gb2).sum(0)) -- cgit v1.2.3 From 0afd84d744a472903d427e3c7ae32e55fdd7b9a7 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 4 Apr 2022 10:23:04 +0200 Subject: [WIP] Add backend dual loss and plan computation for stochastic optimization or regularized OT (#360) * add losses and plan computations and exmaple for dual oiptimization * pep8 * add nice exmaple * update awesome example stochasti dual * add all tests * pep8 + speedup exmaple * add release info --- RELEASES.md | 2 + examples/backends/plot_dual_ot_pytorch.py | 168 ++++++++++++++ .../backends/plot_stoch_continuous_ot_pytorch.py | 189 ++++++++++++++++ ot/stochastic.py | 242 ++++++++++++++++++++- test/test_stochastic.py | 115 +++++++++- 5 files changed, 713 insertions(+), 3 deletions(-) create mode 100644 examples/backends/plot_dual_ot_pytorch.py create mode 100644 examples/backends/plot_stoch_continuous_ot_pytorch.py diff --git a/RELEASES.md b/RELEASES.md index c2bd0d1..45336f7 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,8 @@ #### New features +- Add stochastic loss and OT plan computation for regularized OT and + backend examples(PR #360). - Implementation of factored OT with emd and sinkhorn (PR #358). - A brand new logo for POT (PR #357) - Better list of related examples in quick start guide with `minigallery` (PR #334). diff --git a/examples/backends/plot_dual_ot_pytorch.py b/examples/backends/plot_dual_ot_pytorch.py new file mode 100644 index 0000000..d3f7a66 --- /dev/null +++ b/examples/backends/plot_dual_ot_pytorch.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +r""" +====================================================================== +Dual OT solvers for entropic and quadratic regularized OT with Pytorch +====================================================================== + + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +import ot +import ot.plot + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +n_source_samples = 100 +n_target_samples = 100 +theta = 2 * np.pi / 20 +noise_level = 0.1 + +Xs, ys = ot.datasets.make_data_classif( + 'gaussrot', n_source_samples, nz=noise_level) +Xt, yt = ot.datasets.make_data_classif( + 'gaussrot', n_target_samples, theta=theta, nz=noise_level) + +# one of the target mode changes its variance (no linear mapping) +Xt[yt == 2] *= 3 +Xt = Xt + 4 + + +# %% +# Plot data +# --------- + +pl.figure(1, (10, 5)) +pl.clf() +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +# %% +# Convert data to torch tensors +# ----------------------------- + +xs = torch.tensor(Xs) +xt = torch.tensor(Xt) + +# %% +# Estimating dual variables for entropic OT +# ----------------------------------------- + +u = torch.randn(n_source_samples, requires_grad=True) +v = torch.randn(n_source_samples, requires_grad=True) + +reg = 0.5 + +optimizer = torch.optim.Adam([u, v], lr=1) + +# number of iteration +n_iter = 200 + + +losses = [] + +for i in range(n_iter): + + # generate noise samples + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_entropic(u, v, xs, xt, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(2) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + +Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg) + +# %% +# Plot teh estimated entropic OT plan +# ----------------------------------- + +pl.figure(3, (10, 5)) +pl.clf() +ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1) +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.legend(loc=0) +pl.title('Source and target distributions') + + +# %% +# Estimating dual variables for quadratic OT +# ----------------------------------------- + +u = torch.randn(n_source_samples, requires_grad=True) +v = torch.randn(n_source_samples, requires_grad=True) + +reg = 0.01 + +optimizer = torch.optim.Adam([u, v], lr=1) + +# number of iteration +n_iter = 200 + + +losses = [] + + +for i in range(n_iter): + + # generate noise samples + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_quadratic(u, v, xs, xt, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(4) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + +Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg) + + +# %% +# Plot the estimated quadratic OT plan +# ----------------------------------- + +pl.figure(5, (10, 5)) +pl.clf() +ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1) +pl.scatter(Xs[:, 0], Xs[:, 1], marker='+', label='Source samples', zorder=2) +pl.scatter(Xt[:, 0], Xt[:, 1], marker='o', label='Target samples', zorder=2) +pl.legend(loc=0) +pl.title('OT plan with quadratic regularization') diff --git a/examples/backends/plot_stoch_continuous_ot_pytorch.py b/examples/backends/plot_stoch_continuous_ot_pytorch.py new file mode 100644 index 0000000..6d9b916 --- /dev/null +++ b/examples/backends/plot_stoch_continuous_ot_pytorch.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +r""" +====================================================================== +Continuous OT plan estimation with Pytorch +====================================================================== + + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +from torch import nn +import ot +import ot.plot + +# %% +# Data generation +# --------------- + +torch.manual_seed(42) +np.random.seed(42) + +n_source_samples = 10000 +n_target_samples = 10000 +theta = 2 * np.pi / 20 +noise_level = 0.1 + +Xs = np.random.randn(n_source_samples, 2) * 0.5 +Xt = np.random.randn(n_target_samples, 2) * 2 + +# one of the target mode changes its variance (no linear mapping) +Xt = Xt + 4 + + +# %% +# Plot data +# --------- +nvisu = 300 +pl.figure(1, (5, 5)) +pl.clf() +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', label='Source samples', alpha=0.5) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', label='Target samples', alpha=0.5) +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Source and target distributions') + +# %% +# Convert data to torch tensors +# ----------------------------- + +xs = torch.tensor(Xs) +xt = torch.tensor(Xt) + +# %% +# Estimating deep dual variables for entropic OT +# ---------------------------------------------- + +torch.manual_seed(42) + +# define the MLP model + + +class Potential(torch.nn.Module): + def __init__(self): + super(Potential, self).__init__() + self.fc1 = nn.Linear(2, 200) + self.fc2 = nn.Linear(200, 1) + self.relu = torch.nn.ReLU() # instead of Heaviside step fn + + def forward(self, x): + output = self.fc1(x) + output = self.relu(output) # instead of Heaviside step fn + output = self.fc2(output) + return output.ravel() + + +u = Potential().double() +v = Potential().double() + +reg = 1 + +optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005) + +# number of iteration +n_iter = 1000 +n_batch = 500 + + +losses = [] + +for i in range(n_iter): + + # generate noise samples + + iperms = torch.randint(0, n_source_samples, (n_batch,)) + ipermt = torch.randint(0, n_target_samples, (n_batch,)) + + xsi = xs[iperms] + xti = xt[ipermt] + + # minus because we maximize te dual loss + loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + +pl.figure(2) +pl.plot(losses) +pl.grid() +pl.title('Dual objective (negative)') +pl.xlabel("Iterations") + + +# %% +# Plot the density on arget for a given source sample +# --------------------------------------------------- + + +nv = 100 +xl = np.linspace(ax_bounds[0], ax_bounds[1], nv) +yl = np.linspace(ax_bounds[2], ax_bounds[3], nv) + +XX, YY = np.meshgrid(xl, yl) + +xg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1) + +wxg = np.exp(-((xg[:, 0] - 4)**2 + (xg[:, 1] - 4)**2) / (2 * 2)) +wxg = wxg / np.sum(wxg) + +xg = torch.tensor(xg) +wxg = torch.tensor(wxg) + + +pl.figure(4, (12, 4)) +pl.clf() +pl.subplot(1, 3, 1) + +iv = 2 +Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = Gg.reshape((nv, nv)).detach().numpy() + +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) +pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') +pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample') +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Density of transported source sample') + +pl.subplot(1, 3, 2) + +iv = 3 +Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = Gg.reshape((nv, nv)).detach().numpy() + +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) +pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') +pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample') +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Density of transported source sample') + +pl.subplot(1, 3, 3) + +iv = 6 +Gg = ot.stochastic.plan_dual_entropic(u(xs[iv:iv + 1, :]), v(xg), xs[iv:iv + 1, :], xg, reg=reg, wt=wxg) +Gg = Gg.reshape((nv, nv)).detach().numpy() + +pl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker='+', zorder=2, alpha=0.05) +pl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker='o', zorder=2, alpha=0.05) +pl.scatter(Xs[iv:iv + 1, 0], Xs[iv:iv + 1, 1], s=100, marker='+', label='Source sample', zorder=2, alpha=1, color='C0') +pl.pcolormesh(XX, YY, Gg, cmap='Greens', label='Density of transported sourec sample') +pl.legend(loc=0) +ax_bounds = pl.axis() +pl.title('Density of transported source sample') diff --git a/ot/stochastic.py b/ot/stochastic.py index 693675f..61be9bb 100644 --- a/ot/stochastic.py +++ b/ot/stochastic.py @@ -4,12 +4,14 @@ Stochastic solvers for regularized OT. """ -# Author: Kilian Fatras +# Authors: Kilian Fatras +# Rémi Flamary # # License: MIT License import numpy as np - +from .utils import dist +from .backend import get_backend ############################################################################## # Optimization toolbox for SEMI - DUAL problems @@ -747,3 +749,239 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1, return pi, log else: return pi + + +################################################################################ +# Losses for stochastic optimization +################################################################################ + +def loss_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): + r""" + Compute the dual loss of the entropic OT as in equation (6)-(7) of [19] + + This loss is backend compatible and can be used for stochastic optimization + of the dual potentials. It can be used on the full dataset (beware of + memory) or on minibatches. + + + Parameters + ---------- + u : array-like, shape (ns,) + Source dual potential + v : array-like, shape (nt,) + Target dual potential + xs : array-like, shape (ns,d) + Source samples + xt : array-like, shape (ns,d) + Target samples + reg : float + Regularization term > 0 (default=1) + ws : array-like, shape (ns,), optional + Source sample weights (default unif) + wt : array-like, shape (ns,), optional + Target sample weights (default unif) + metric : string, callable + Ground metric for OT (default quadratic). Can be given as a callable + function taking (xs,xt) as parameters. + + Returns + ------- + dual_loss : array-like + Dual loss (to maximize) + + + References + ---------- + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) + """ + + nx = get_backend(u, v, xs, xt) + + if ws is None: + ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0] + + if callable(metric): + M = metric(xs, xt) + else: + M = dist(xs, xt, metric=metric) + + F = -reg * nx.exp((u[:, None] + v[None, :] - M) / reg) + + return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :]) + + +def plan_dual_entropic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): + r""" + Compute the primal OT plan the entropic OT as in equation (8) of [19] + + This loss is backend compatible and can be used for stochastic optimization + of the dual potentials. It can be used on the full dataset (beware of + memory) or on minibatches. + + + Parameters + ---------- + u : array-like, shape (ns,) + Source dual potential + v : array-like, shape (nt,) + Target dual potential + xs : array-like, shape (ns,d) + Source samples + xt : array-like, shape (ns,d) + Target samples + reg : float + Regularization term > 0 (default=1) + ws : array-like, shape (ns,), optional + Source sample weights (default unif) + wt : array-like, shape (ns,), optional + Target sample weights (default unif) + metric : string, callable + Ground metric for OT (default quadratic). Can be given as a callable + function taking (xs,xt) as parameters. + + Returns + ------- + G : array-like + Primal OT plan + + + References + ---------- + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) + """ + + nx = get_backend(u, v, xs, xt) + + if ws is None: + ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0] + + if callable(metric): + M = metric(xs, xt) + else: + M = dist(xs, xt, metric=metric) + + H = nx.exp((u[:, None] + v[None, :] - M) / reg) + + return ws[:, None] * H * wt[None, :] + + +def loss_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): + r""" + Compute the dual loss of the quadratic regularized OT as in equation (6)-(7) of [19] + + This loss is backend compatible and can be used for stochastic optimization + of the dual potentials. It can be used on the full dataset (beware of + memory) or on minibatches. + + + Parameters + ---------- + u : array-like, shape (ns,) + Source dual potential + v : array-like, shape (nt,) + Target dual potential + xs : array-like, shape (ns,d) + Source samples + xt : array-like, shape (ns,d) + Target samples + reg : float + Regularization term > 0 (default=1) + ws : array-like, shape (ns,), optional + Source sample weights (default unif) + wt : array-like, shape (ns,), optional + Target sample weights (default unif) + metric : string, callable + Ground metric for OT (default quadratic). Can be given as a callable + function taking (xs,xt) as parameters. + + Returns + ------- + dual_loss : array-like + Dual loss (to maximize) + + + References + ---------- + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) + """ + + nx = get_backend(u, v, xs, xt) + + if ws is None: + ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0] + + if callable(metric): + M = metric(xs, xt) + else: + M = dist(xs, xt, metric=metric) + + F = -1.0 / (4 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0)**2 + + return nx.sum(u * ws) + nx.sum(v * wt) + nx.sum(ws[:, None] * F * wt[None, :]) + + +def plan_dual_quadratic(u, v, xs, xt, reg=1, ws=None, wt=None, metric='sqeuclidean'): + r""" + Compute the primal OT plan the quadratic regularized OT as in equation (8) of [19] + + This loss is backend compatible and can be used for stochastic optimization + of the dual potentials. It can be used on the full dataset (beware of + memory) or on minibatches. + + + Parameters + ---------- + u : array-like, shape (ns,) + Source dual potential + v : array-like, shape (nt,) + Target dual potential + xs : array-like, shape (ns,d) + Source samples + xt : array-like, shape (ns,d) + Target samples + reg : float + Regularization term > 0 (default=1) + ws : array-like, shape (ns,), optional + Source sample weights (default unif) + wt : array-like, shape (ns,), optional + Target sample weights (default unif) + metric : string, callable + Ground metric for OT (default quadratic). Can be given as a callable + function taking (xs,xt) as parameters. + + Returns + ------- + G : array-like + Primal OT plan + + + References + ---------- + .. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. Large-scale Optimal Transport and Mapping Estimation. International Conference on Learning Representation (2018) + """ + + nx = get_backend(u, v, xs, xt) + + if ws is None: + ws = nx.ones(xs.shape[0], type_as=xs) / xs.shape[0] + + if wt is None: + wt = nx.ones(xt.shape[0], type_as=xt) / xt.shape[0] + + if callable(metric): + M = metric(xs, xt) + else: + M = dist(xs, xt, metric=metric) + + H = 1.0 / (2 * reg) * nx.maximum(u[:, None] + v[None, :] - M, 0.0) + + return ws[:, None] * H * wt[None, :] diff --git a/test/test_stochastic.py b/test/test_stochastic.py index 736df32..2b5c0fb 100644 --- a/test/test_stochastic.py +++ b/test/test_stochastic.py @@ -8,7 +8,8 @@ for descrete and semicontinous measures from the POT library. """ -# Author: Kilian Fatras +# Authors: Kilian Fatras +# Rémi Flamary # # License: MIT License @@ -213,3 +214,115 @@ def test_dual_sgd_sinkhorn(): G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03) np.testing.assert_allclose( G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd + + +def test_loss_dual_entropic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + ot.stochastic.loss_dual_entropic(u, v, xs, xt) + + ot.stochastic.loss_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + +def test_plan_dual_entropic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + G1 = ot.stochastic.plan_dual_entropic(u, v, xs, xt) + + assert np.all(nx.to_numpy(G1) >= 0) + assert G1.shape[0] == 50 + assert G1.shape[1] == 40 + + G2 = ot.stochastic.plan_dual_entropic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + assert np.all(nx.to_numpy(G2) >= 0) + assert G2.shape[0] == 50 + assert G2.shape[1] == 40 + + +def test_loss_dual_quadratic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + ot.stochastic.loss_dual_quadratic(u, v, xs, xt) + + ot.stochastic.loss_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + +def test_plan_dual_quadratic(nx): + + nx.seed(0) + + xs = nx.randn(50, 2) + xt = nx.randn(40, 2) + 2 + + ws = nx.rand(50) + ws = ws / nx.sum(ws) + + wt = nx.rand(40) + wt = wt / nx.sum(wt) + + u = nx.randn(50) + v = nx.randn(40) + + def metric(x, y): + return -nx.dot(x, y.T) + + G1 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt) + + assert np.all(nx.to_numpy(G1) >= 0) + assert G1.shape[0] == 50 + assert G1.shape[1] == 40 + + G2 = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, ws=ws, wt=wt, metric=metric) + + assert np.all(nx.to_numpy(G2) >= 0) + assert G2.shape[0] == 50 + assert G2.shape[1] == 40 -- cgit v1.2.3 From ad02112d4288f3efdd5bc6fc6e45444313bba871 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 5 Apr 2022 11:57:10 +0200 Subject: [MRG] Update examples in the doc (#359) * add transparent color logo * add transparent color logo * move screenkhorn * move stochastic and install ffmpeg on circleci * try something * add sudo * install ffmpeg before python * cleanup examples * test svg scrapper * add animation for reg path * better example OT sivergence * update ttles and add plots * update free support * proper figure indexes * have less frame sin animation * update readme and release file * add tests for python 3.10 --- .circleci/config.yml | 7 + .github/workflows/build_tests.yml | 6 +- README.md | 10 +- RELEASES.md | 3 +- docs/source/_static/images/logo.png | Bin 5038 -> 4325 bytes docs/source/_static/images/logo.svg | 174 +++++++++---------- docs/source/conf.py | 6 + .../backends/plot_sliced_wass_grad_flow_pytorch.py | 2 + examples/backends/plot_wass1d_torch.py | 8 +- .../barycenters/plot_free_support_barycenter.py | 55 +++--- examples/others/plot_logo.py | 8 +- examples/others/plot_screenkhorn_1D.py | 71 ++++++++ examples/others/plot_stochastic.py | 189 +++++++++++++++++++++ examples/plot_OT_1D.py | 12 +- examples/plot_OT_1D_smooth.py | 6 +- examples/plot_OT_2D_samples.py | 2 +- examples/plot_OT_L1_vs_L2.py | 32 ++-- examples/plot_compute_emd.py | 72 +++++--- examples/plot_optim_OTreg.py | 38 ++++- examples/plot_screenkhorn_1D.py | 71 -------- examples/plot_stochastic.py | 189 --------------------- examples/sliced-wasserstein/README.txt | 2 +- examples/sliced-wasserstein/plot_variance.py | 8 +- examples/unbalanced-partial/plot_UOT_1D.py | 17 +- examples/unbalanced-partial/plot_regpath.py | 88 +++++++++- 25 files changed, 628 insertions(+), 448 deletions(-) create mode 100644 examples/others/plot_screenkhorn_1D.py create mode 100644 examples/others/plot_stochastic.py delete mode 100644 examples/plot_screenkhorn_1D.py delete mode 100644 examples/plot_stochastic.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 39c19fb..77ab45c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -35,6 +35,12 @@ jobs: - data-cache-0 - pip-cache + - run: + name: Install ffmpeg + command: | + sudo apt update + sudo apt install ffmpeg + - run: name: Get Python running command: | @@ -50,6 +56,7 @@ jobs: paths: - ~/.cache/pip + # Look at what we have and fail early if there is some library conflict - run: name: Check installation diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 3c99da8..ce725c6 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -22,7 +22,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 @@ -93,7 +93,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 @@ -120,7 +120,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 diff --git a/README.md b/README.md index ec5d221..0c3bd19 100644 --- a/README.md +++ b/README.md @@ -29,8 +29,11 @@ POT provides the following generic OT solvers (links to examples): * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] -* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) -* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] +* [Stochastic + solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and + [differentiable losses](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html) for + Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] @@ -119,7 +122,7 @@ Note that for easier access the module is named `ot` instead of `pot`. ### Dependencies -Some sub-modules require additional dependences which are discussed below +Some sub-modules require additional dependencies which are discussed below * **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with: @@ -127,7 +130,6 @@ Some sub-modules require additional dependences which are discussed below pip install pymanopt autograd ``` -* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend. ## Examples diff --git a/RELEASES.md b/RELEASES.md index 45336f7..7d458f3 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- Update examples in the gallery (PR #359). - Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360). - Implementation of factored OT with emd and sinkhorn (PR #358). @@ -254,7 +255,7 @@ are coming for the next versions. #### Closed issues -- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments (PR +- Add JMLR paper to the readme and Mathieu Blondel to the Acknoledgments (PR #231, #232) - Bug in Unbalanced OT example (Issue #127) - Clean Cython output when calling setup.py clean (Issue #122) diff --git a/docs/source/_static/images/logo.png b/docs/source/_static/images/logo.png index 7be5df7..2dd6f65 100644 Binary files a/docs/source/_static/images/logo.png and b/docs/source/_static/images/logo.png differ diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg index 0bf2cb7..39fe900 100644 --- a/docs/source/_static/images/logo.svg +++ b/docs/source/_static/images/logo.svg @@ -1,24 +1,23 @@ - - + - + - 2022-03-17T17:25:30.736761 + 2022-03-30T17:25:32.476826 image/svg+xml - Matplotlib v3.3.3, https://matplotlib.org/ + Matplotlib v3.5.1, https://matplotlib.org/ - + @@ -26,103 +25,104 @@ L 209.7 75.384 L 209.7 0 L 0 0 +L 0 75.384 z -" style="fill:#ffffff;"/> +" style="fill: none"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" clip-path="url(#p367fff45ba)" style="fill: none; stroke: #000000; stroke-opacity: 0.6; stroke-width: 3; stroke-linecap: square"/> - +" style="stroke: #000000"/> - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - +" style="stroke: #000000"/> - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - - + + diff --git a/docs/source/conf.py b/docs/source/conf.py index 60d0bb7..9526518 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,9 +17,15 @@ import os import re try: import sphinx_gallery + except ImportError: print("warning sphinx-gallery not installed") + + + + + # !!!! allow readthedoc compilation try: from unittest.mock import MagicMock diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py index 05b9952..cf5d64d 100644 --- a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -27,6 +27,8 @@ Machine Learning (pp. 4104-4113). PMLR. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + # %% # Loading the data diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py index 0abdd6d..cd8e2fd 100644 --- a/examples/backends/plot_wass1d_torch.py +++ b/examples/backends/plot_wass1d_torch.py @@ -1,9 +1,9 @@ r""" -================================= -Wasserstein 1D with PyTorch -================================= +================================================= +Wasserstein 1D (flow and barycenter) with PyTorch +================================================= -In this small example, we consider the following minization problem: +In this small example, we consider the following minimization problem: .. math:: \mu^* = \min_\mu W(\mu,\nu) diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 2d68a39..226dfeb 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -9,61 +9,62 @@ sum of diracs. """ -# Author: Vivien Seguy +# Authors: Vivien Seguy +# Rémi Flamary # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import matplotlib.pylab as pl import ot -############################################################################## +# %% # Generate data # ------------- -N = 3 +N = 2 d = 2 -measures_locations = [] -measures_weights = [] - -for i in range(N): - n_i = np.random.randint(low=1, high=20) # nb samples +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/duck.png').astype(np.float64)[::4, ::4, 2] - mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) - A_i = np.random.rand(d, d) - cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 - x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations - b_i = np.random.uniform(0., 1., (n_i,)) - b_i = b_i / np.sum(b_i) # Dirac weights +measures_locations = [x1, x2] +measures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])] - measures_locations.append(x_i) - measures_weights.append(b_i) +pl.figure(1, (12, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.title('Distributions') -############################################################################## +# %% # Compute free support barycenter # ------------------------------- -k = 10 # number of Diracs of the barycenter +k = 200 # number of Diracs of the barycenter X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) - -############################################################################## -# Plot data +# %% +# Plot the barycenter # --------- -pl.figure(1) -for (x_i, b_i) in zip(measures_locations, measures_weights): - color = np.random.randint(low=1, high=10 * N) - pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure') -pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter') +pl.figure(2, (8, 3)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) +pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') pl.title('Data measures and their barycenter') -pl.legend(loc=0) +pl.legend(loc="lower right") pl.show() diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py index afddcad..9414371 100644 --- a/examples/others/plot_logo.py +++ b/examples/others/plot_logo.py @@ -7,8 +7,8 @@ Logo of the POT toolbox In this example we plot the logo of the POT toolbox. -A specificity of this logo is that it is done 100% in Python and generated using -matplotlib using the EMD solver from POT. +This logo is that it is done 100% in Python and generated using +matplotlib and ploting teh solution of the EMD solver from POT. """ @@ -86,8 +86,8 @@ pl.axis('equal') pl.axis('off') # Save logo file -# pl.savefig('logo.svg', dpi=150, bbox_inches='tight') -# pl.savefig('logo.png', dpi=150, bbox_inches='tight') +# pl.savefig('logo.svg', dpi=150, transparent=True, bbox_inches='tight') +# pl.savefig('logo.png', dpi=150, transparent=True, bbox_inches='tight') # %% # Plot the logo (dark background) diff --git a/examples/others/plot_screenkhorn_1D.py b/examples/others/plot_screenkhorn_1D.py new file mode 100644 index 0000000..2023649 --- /dev/null +++ b/examples/others/plot_screenkhorn_1D.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +======================================== +Screened optimal transport (Screenkhorn) +======================================== + +This example illustrates the computation of Screenkhorn [26]. + +[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). +Screening Sinkhorn Algorithm for Regularized Optimal Transport, +Advances in Neural Information Processing Systems 33 (NeurIPS). +""" + +# Author: Mokhtar Z. Alaya +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot.plot +from ot.datasets import make_1D_gauss as gauss +from ot.bregman import screenkhorn + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + +############################################################################## +# Solve Screenkhorn +# ----------------------- + +# Screenkhorn +lambd = 2e-03 # entropy parameter +ns_budget = 30 # budget number of points to be keeped in the source distribution +nt_budget = 30 # budget number of points to be keeped in the target distribution + +G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') +pl.show() diff --git a/examples/others/plot_stochastic.py b/examples/others/plot_stochastic.py new file mode 100644 index 0000000..3a1ef31 --- /dev/null +++ b/examples/others/plot_stochastic.py @@ -0,0 +1,189 @@ +""" +=================== +Stochastic examples +=================== + +This example is designed to show how to use the stochatic optimization +algorithms for discrete and semi-continuous measures from the POT library. + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. +Stochastic Optimization for Large-scale Optimal Transport. +Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & +Blondel, M. Large-scale Optimal Transport and Mapping Estimation. +International Conference on Learning Representation (2018) + +""" + +# Author: Kilian Fatras +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np +import ot +import ot.plot + + +############################################################################# +# Compute the Transportation Matrix for the Semi-Dual Problem +# ----------------------------------------------------------- +# +# Discrete case +# ````````````` +# +# Sample two discrete measures for the discrete case and compute their cost +# matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 1000 + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# Call the "SAG" method to find the transportation matrix in the discrete case + +method = "SAG" +sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax) +print(sag_pi) + +############################################################################# +# Semi-Continuous Case +# ```````````````````` +# +# Sample one general measure a, one discrete measures b for the semicontinous +# case, the points where source and target measures are defined and compute the +# cost matrix. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 1000 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# Call the "ASGD" method to find the transportation matrix in the semicontinous +# case. + +method = "ASGD" +asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, + numItermax, log=log) +print(log_asgd['alpha'], log_asgd['beta']) +print(asgd_pi) + +############################################################################# +# Compare the results with the Sinkhorn algorithm + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + + +############################################################################## +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SAG + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') +pl.show() + + +############################################################################## +# For ASGD + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') +pl.show() + + +############################################################################## +# For Sinkhorn + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() + + +############################################################################# +# Compute the Transportation Matrix for the Dual Problem +# ------------------------------------------------------ +# +# Semi-continuous case +# ```````````````````` +# +# Sample one general measure a, one discrete measures b for the semi-continuous +# case and compute the cost matrix c. + +n_source = 7 +n_target = 4 +reg = 1 +numItermax = 100000 +lr = 0.1 +batch_size = 3 +log = True + +a = ot.utils.unif(n_source) +b = ot.utils.unif(n_target) + +rng = np.random.RandomState(0) +X_source = rng.randn(n_source, 2) +Y_target = rng.randn(n_target, 2) +M = ot.dist(X_source, Y_target) + +############################################################################# +# +# Call the "SGD" dual method to find the transportation matrix in the +# semi-continuous case + +sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, + batch_size, numItermax, + lr, log=log) +print(log_sgd['alpha'], log_sgd['beta']) +print(sgd_dual_pi) + +############################################################################# +# +# Compare the results with the Sinkhorn algorithm +# ``````````````````````````````````````````````` +# +# Call the Sinkhorn algorithm from POT + +sinkhorn_pi = ot.sinkhorn(a, b, M, reg) +print(sinkhorn_pi) + +############################################################################## +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SGD + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') +pl.show() + + +############################################################################## +# For Sinkhorn + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') +pl.show() diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index 15ead96..62f0b7d 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================== -1D optimal transport -==================== +====================================== +Optimal Transport for 1D distributions +====================================== This example illustrates the computation of EMD and Sinkhorn transport plans and their visualization. @@ -64,7 +64,11 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') #%% EMD -G0 = ot.emd(a, b, M) +# use fast 1D solver +G0 = ot.emd_1d(x, x, a, b) + +# Equivalent to +# G0 = ot.emd(a, b, M) pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index b07f99f..5415e4f 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -=========================== -1D smooth optimal transport -=========================== +================================ +Smooth optimal transport example +================================ This example illustrates the computation of EMD, Sinkhorn and smooth OT plans and their visualization. diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index c3a7cd8..1d82fb8 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ==================================================== -2D Optimal transport between empirical distributions +Optimal Transport between 2D empirical distributions ==================================================== Illustration of 2D optimal transport between discributions that are weighted diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index cb94574..cce51f8 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -========================================== -2D Optimal transport for different metrics -========================================== +================================================ +Optimal Transport with different gournd metrics +================================================ -2D OT on empirical distributio with different gound metric. +2D OT on empirical distributio with different ground metric. Stole the figure idea from Fig. 1 and 2 in https://arxiv.org/pdf/1706.07650.pdf @@ -23,7 +23,7 @@ import matplotlib.pylab as pl import ot import ot.plot -############################################################################## +# %% # Dataset 1 : uniform sampling # ---------------------------- @@ -46,7 +46,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() # Data @@ -71,7 +71,7 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock cost') pl.tight_layout() ############################################################################## @@ -109,22 +109,22 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() -############################################################################## +# %% # Dataset 2 : Partial circle # -------------------------- -n = 50 # nb samples +n = 20 # nb samples xtot = np.zeros((n + 1, 2)) xtot[:, 0] = np.cos( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xtot[:, 1] = np.sin( - (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi) + (np.arange(n + 1) + 1.0) * 0.8 / (n + 2) * 2 * np.pi) xs = xtot[:n, :] xt = xtot[1:, :] @@ -140,7 +140,7 @@ M2 = ot.dist(xs, xt, metric='sqeuclidean') M2 /= M2.max() # loss matrix -Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean')) +Mp = ot.dist(xs, xt, metric='cityblock') Mp /= Mp.max() @@ -166,13 +166,13 @@ pl.title('Squared Euclidean cost') pl.subplot(1, 3, 3) pl.imshow(Mp, interpolation='nearest') -pl.title('Sqrt Euclidean cost') +pl.title('L1 (cityblock) cost') pl.tight_layout() ############################################################################## # Dataset 2 : Plot OT Matrices # ----------------------------- - +# #%% EMD G1 = ot.emd(a, b, M1) @@ -204,7 +204,7 @@ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.axis('equal') # pl.legend(loc=0) -pl.title('OT sqrt Euclidean') +pl.title('OT L1 (cityblock)') pl.tight_layout() pl.show() diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 527a847..36cc7da 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ -================= -Plot multiple EMD -================= +================== +OT distances in 1D +================== -Shows how to compute multiple EMD and Sinkhorn with two different +Shows how to compute multiple Wassersein and Sinkhorn with two different ground metrics and plot their values for different distributions. @@ -14,7 +14,7 @@ ground metrics and plot their values for different distributions. # # License: MIT License -# sphinx_gallery_thumbnail_number = 3 +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl @@ -29,7 +29,7 @@ from ot.datasets import make_1D_gauss as gauss #%% parameters n = 100 # nb bins -n_target = 50 # nb target distributions +n_target = 20 # nb target distributions # bin positions @@ -47,9 +47,9 @@ for i, m in enumerate(lst_m): # loss matrix and normalization M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean') -M /= M.max() +M /= M.max() * 0.1 M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean') -M2 /= M2.max() +M2 /= M2.max() * 0.1 ############################################################################## # Plot data @@ -59,10 +59,12 @@ M2 /= M2.max() pl.figure(1) pl.subplot(2, 1, 1) -pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, a, 'r', label='Source distribution') pl.title('Source distribution') pl.subplot(2, 1, 2) -pl.plot(x, B, label='Target distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') pl.title('Target distributions') pl.tight_layout() @@ -73,14 +75,27 @@ pl.tight_layout() #%% Compute and plot distributions and loss matrix -d_emd = ot.emd2(a, B, M) # direct computation of EMD -d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2 - +d_emd = ot.emd2(a, B, M) # direct computation of OT loss +d_emd2 = ot.emd2(a, B, M2) # direct computation of OT loss with metrixc M2 +d_tv = [np.sum(abs(a - B[:, i])) for i in range(n_target)] pl.figure(2) -pl.plot(d_emd, label='Euclidean EMD') -pl.plot(d_emd2, label='Squared Euclidean EMD') -pl.title('EMD distances') +pl.subplot(2, 1, 1) +pl.plot(x, a, 'r', label='Source distribution') +pl.title('Distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') +pl.ylim((-.01, 0.13)) +pl.xticks(()) +pl.legend() +pl.subplot(2, 1, 2) +pl.plot(d_emd, label='Euclidean OT') +pl.plot(d_emd2, label='Squared Euclidean OT') +pl.plot(d_tv, label='Total Variation (TV)') +#pl.xlim((-7,23)) +pl.xlabel('Displacement') +pl.title('Divergences') pl.legend() ############################################################################## @@ -88,17 +103,30 @@ pl.legend() # ----------------------------------------- #%% -reg = 1e-2 +reg = 1e-1 d_sinkhorn = ot.sinkhorn2(a, B, M, reg) d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg) -pl.figure(2) +pl.figure(3) pl.clf() -pl.plot(d_emd, label='Euclidean EMD') -pl.plot(d_emd2, label='Squared Euclidean EMD') + +pl.subplot(2, 1, 1) +pl.plot(x, a, 'r', label='Source distribution') +pl.title('Distributions') +for i in range(n_target): + pl.plot(x, B[:, i], 'b', alpha=i / n_target) +pl.plot(x, B[:, -1], 'b', label='Target distributions') +pl.ylim((-.01, 0.13)) +pl.xticks(()) +pl.legend() +pl.subplot(2, 1, 2) +pl.plot(d_emd, label='Euclidean OT') +pl.plot(d_emd2, label='Squared Euclidean OT') pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn') pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn') -pl.title('EMD distances') +pl.plot(d_tv, label='Total Variation (TV)') +#pl.xlim((-7,23)) +pl.xlabel('Displacement') +pl.title('Divergences') pl.legend() - pl.show() diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index 5eb15bd..7b021d2 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -24,7 +24,7 @@ arXiv preprint arXiv:1510.06567. """ -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 5 import numpy as np import matplotlib.pylab as pl @@ -58,7 +58,7 @@ M /= M.max() G0 = ot.emd(a, b, M) -pl.figure(3, figsize=(5, 5)) +pl.figure(1, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') ############################################################################## @@ -80,7 +80,7 @@ reg = 1e-1 Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True) -pl.figure(3) +pl.figure(2) ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg') ############################################################################## @@ -102,7 +102,7 @@ reg = 1e-3 Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg') ############################################################################## @@ -125,6 +125,34 @@ reg2 = 1e-1 Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True) -pl.figure(5, figsize=(5, 5)) +pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg') pl.show() + + +# %% +# Comparison of the OT matrices + +nvisu = 40 + +pl.figure(5, figsize=(10, 4)) + +pl.subplot(2, 2, 1) +pl.imshow(G0[:nvisu, :]) +pl.axis('off') +pl.title('Exact OT') + +pl.subplot(2, 2, 2) +pl.imshow(Gl2[:nvisu, :]) +pl.axis('off') +pl.title('Frobenius reg.') + +pl.subplot(2, 2, 3) +pl.imshow(Ge[:nvisu, :]) +pl.axis('off') +pl.title('Entropic reg.') + +pl.subplot(2, 2, 4) +pl.imshow(Gel2[:nvisu, :]) +pl.axis('off') +pl.title('Entropic + Frobenius reg.') diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py deleted file mode 100644 index 785642a..0000000 --- a/examples/plot_screenkhorn_1D.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -""" -=============================== -1D Screened optimal transport -=============================== - -This example illustrates the computation of Screenkhorn [26]. - -[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). -Screening Sinkhorn Algorithm for Regularized Optimal Transport, -Advances in Neural Information Processing Systems 33 (NeurIPS). -""" - -# Author: Mokhtar Z. Alaya -# -# License: MIT License - -import numpy as np -import matplotlib.pylab as pl -import ot.plot -from ot.datasets import make_1D_gauss as gauss -from ot.bregman import screenkhorn - -############################################################################## -# Generate data -# ------------- - -#%% parameters - -n = 100 # nb bins - -# bin positions -x = np.arange(n, dtype=np.float64) - -# Gaussian distributions -a = gauss(n, m=20, s=5) # m= mean, s= std -b = gauss(n, m=60, s=10) - -# loss matrix -M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) -M /= M.max() - -############################################################################## -# Plot distributions and loss matrix -# ---------------------------------- - -#%% plot the distributions - -pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') -pl.legend() - -# plot distributions and loss matrix - -pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') - -############################################################################## -# Solve Screenkhorn -# ----------------------- - -# Screenkhorn -lambd = 2e-03 # entropy parameter -ns_budget = 30 # budget number of points to be keeped in the source distribution -nt_budget = 30 # budget number of points to be keeped in the target distribution - -G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') -pl.show() diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py deleted file mode 100644 index 3a1ef31..0000000 --- a/examples/plot_stochastic.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -=================== -Stochastic examples -=================== - -This example is designed to show how to use the stochatic optimization -algorithms for discrete and semi-continuous measures from the POT library. - -[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. -Stochastic Optimization for Large-scale Optimal Transport. -Advances in Neural Information Processing Systems (2016). - -[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & -Blondel, M. Large-scale Optimal Transport and Mapping Estimation. -International Conference on Learning Representation (2018) - -""" - -# Author: Kilian Fatras -# -# License: MIT License - -import matplotlib.pylab as pl -import numpy as np -import ot -import ot.plot - - -############################################################################# -# Compute the Transportation Matrix for the Semi-Dual Problem -# ----------------------------------------------------------- -# -# Discrete case -# ````````````` -# -# Sample two discrete measures for the discrete case and compute their cost -# matrix c. - -n_source = 7 -n_target = 4 -reg = 1 -numItermax = 1000 - -a = ot.utils.unif(n_source) -b = ot.utils.unif(n_target) - -rng = np.random.RandomState(0) -X_source = rng.randn(n_source, 2) -Y_target = rng.randn(n_target, 2) -M = ot.dist(X_source, Y_target) - -############################################################################# -# Call the "SAG" method to find the transportation matrix in the discrete case - -method = "SAG" -sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax) -print(sag_pi) - -############################################################################# -# Semi-Continuous Case -# ```````````````````` -# -# Sample one general measure a, one discrete measures b for the semicontinous -# case, the points where source and target measures are defined and compute the -# cost matrix. - -n_source = 7 -n_target = 4 -reg = 1 -numItermax = 1000 -log = True - -a = ot.utils.unif(n_source) -b = ot.utils.unif(n_target) - -rng = np.random.RandomState(0) -X_source = rng.randn(n_source, 2) -Y_target = rng.randn(n_target, 2) -M = ot.dist(X_source, Y_target) - -############################################################################# -# Call the "ASGD" method to find the transportation matrix in the semicontinous -# case. - -method = "ASGD" -asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, - numItermax, log=log) -print(log_asgd['alpha'], log_asgd['beta']) -print(asgd_pi) - -############################################################################# -# Compare the results with the Sinkhorn algorithm - -sinkhorn_pi = ot.sinkhorn(a, b, M, reg) -print(sinkhorn_pi) - - -############################################################################## -# Plot Transportation Matrices -# ```````````````````````````` -# -# For SAG - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') -pl.show() - - -############################################################################## -# For ASGD - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') -pl.show() - - -############################################################################## -# For Sinkhorn - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') -pl.show() - - -############################################################################# -# Compute the Transportation Matrix for the Dual Problem -# ------------------------------------------------------ -# -# Semi-continuous case -# ```````````````````` -# -# Sample one general measure a, one discrete measures b for the semi-continuous -# case and compute the cost matrix c. - -n_source = 7 -n_target = 4 -reg = 1 -numItermax = 100000 -lr = 0.1 -batch_size = 3 -log = True - -a = ot.utils.unif(n_source) -b = ot.utils.unif(n_target) - -rng = np.random.RandomState(0) -X_source = rng.randn(n_source, 2) -Y_target = rng.randn(n_target, 2) -M = ot.dist(X_source, Y_target) - -############################################################################# -# -# Call the "SGD" dual method to find the transportation matrix in the -# semi-continuous case - -sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, - batch_size, numItermax, - lr, log=log) -print(log_sgd['alpha'], log_sgd['beta']) -print(sgd_dual_pi) - -############################################################################# -# -# Compare the results with the Sinkhorn algorithm -# ``````````````````````````````````````````````` -# -# Call the Sinkhorn algorithm from POT - -sinkhorn_pi = ot.sinkhorn(a, b, M, reg) -print(sinkhorn_pi) - -############################################################################## -# Plot Transportation Matrices -# ```````````````````````````` -# -# For SGD - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') -pl.show() - - -############################################################################## -# For Sinkhorn - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') -pl.show() diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt index a575345..73e6122 100644 --- a/examples/sliced-wasserstein/README.txt +++ b/examples/sliced-wasserstein/README.txt @@ -1,4 +1,4 @@ Sliced Wasserstein Distance ---------------------------- \ No newline at end of file +--------------------------- diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 7d73907..f12b522 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -============================== -2D Sliced Wasserstein Distance -============================== +=============================================== +Sliced Wasserstein Distance on 2D distributions +=============================================== This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. @@ -16,6 +16,8 @@ measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import matplotlib.pylab as pl import numpy as np diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 183849c..06dd02d 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -12,6 +12,8 @@ using a Kullback-Leibler relaxation. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot @@ -69,7 +71,20 @@ epsilon = 0.1 # entropy parameter alpha = 1. # Unbalanced KL relaxation parameter Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True) -pl.figure(4, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn') pl.show() + + +# %% +# plot the transported mass +# ------------------------- + +pl.figure(4, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.fill(x, Gs.sum(1), 'b', alpha=0.5, label='Transported source') +pl.fill(x, Gs.sum(0), 'r', alpha=0.5, label='Transported target') +pl.legend(loc='upper right') +pl.title('Distributions and transported mass for UOT') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py index 4a51c2d..782e8c2 100644 --- a/examples/unbalanced-partial/plot_regpath.py +++ b/examples/unbalanced-partial/plot_regpath.py @@ -15,11 +15,12 @@ penalized linear regression. # Author: Haoran Wu # License: MIT License +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl import ot - +import matplotlib.animation as animation ############################################################################## # Generate data # ------------- @@ -72,6 +73,9 @@ t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, ############################################################################## # Plot the regularization path # ---------------- +# +# The OT plan is ploted as a function of $\gamma$ that is the inverse of the +# weight on the marginal relaxations. #%% fully relaxed l2-penalized UOT @@ -103,13 +107,53 @@ for p in range(4): pl.show() +# %% +# Animation of the regpath for UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(-.5, -2.5, nv) + +pl.figure(3) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list, + t_list) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) + + ############################################################################## # Plot the semi-relaxed regularization path # ------------------- #%% semi-relaxed l2-penalized UOT -pl.figure(3) +pl.figure(4) selected_gamma = [10, 1, 1e-1, 1e-2] for p in range(4): tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, @@ -133,3 +177,43 @@ for p in range(4): if p < 2: pl.xticks(()) pl.show() + + +# %% +# Animation of the regpath for semi-relaxed UOT l2 +# ------------------------ + +nv = 100 +g_list_v = np.logspace(2.5, -2, nv) + +pl.figure(5) + + +def _update_plot(iv): + pl.clf() + tp = ot.regpath.compute_transport_plan(g_list_v[iv], g_list2, + t_list2) + P = tp.reshape((n, n)) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.5) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 4, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 4, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $\ell_2$ UOT $\gamma$={:1.3f}'.format(g_list_v[iv]), + fontsize=11) + return 1 + + +i = 0 +_update_plot(i) + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000) -- cgit v1.2.3 From 0b223ff883fd73601984a92c31cb70d4aded16e8 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 7 Apr 2022 14:18:54 +0200 Subject: [MRG] Remove deprecated ot.gpu submodule (#361) * remove all cpu submodule and tests * speedup tests gromov --- README.md | 2 +- RELEASES.md | 1 + docs/source/quickstart.rst | 58 +++++++++++--- ot/gpu/__init__.py | 50 ------------ ot/gpu/bregman.py | 196 --------------------------------------------- ot/gpu/da.py | 144 --------------------------------- ot/gpu/utils.py | 101 ----------------------- test/test_gpu.py | 106 ------------------------ test/test_gromov.py | 129 ++++++++++++++--------------- 9 files changed, 113 insertions(+), 674 deletions(-) delete mode 100644 ot/gpu/__init__.py delete mode 100644 ot/gpu/bregman.py delete mode 100644 ot/gpu/da.py delete mode 100644 ot/gpu/utils.py delete mode 100644 test/test_gpu.py diff --git a/README.md b/README.md index 0c3bd19..2ace69c 100644 --- a/README.md +++ b/README.md @@ -185,7 +185,7 @@ The contributors to this library are * [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation) * [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT) * [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation) -* [Léo Gautheron](https://github.com/aje) (GPU implementation) +* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation) * [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes) * [Stanislas Chambon](https://slasnista.github.io/) (DA classes) * [Antoine Rolet](https://arolet.github.io/) (EMD solver debug) diff --git a/RELEASES.md b/RELEASES.md index 7d458f3..b54a84a 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,7 @@ #### New features +- remode deprecated `ot.gpu` submodule (PR #361) - Update examples in the gallery (PR #359). - Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360). diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 09a362b..b4cc8ab 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -1028,15 +1028,6 @@ FAQ speedup can be obtained by using a GPU implementation since all operations are matrix/vector products. -4. **Using GPU fails with error: module 'ot' has no attribute 'gpu'** - - In order to limit import time and hard dependencies in POT. we do not import - some sub-modules automatically with :code:`import ot`. In order to use the - acceleration in :any:`ot.gpu` you need first to import is with - :code:`import ot.gpu`. - - See `Issue #85 `__ and :any:`ot.gpu` - for more details. References @@ -1172,3 +1163,52 @@ References .. [30] Flamary, Rémi, et al. "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching." NIPS Workshop on Optimal Transport and Machine Learning OTML. 2014. + +.. [31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters of + measures + `_\ + , Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +.. [32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance `_\ , Proceedings of the 38th International Conference on Machine Learning (ICML). + +.. [33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov Wasserstein + `_\ , Machine + Learning Journal (MJL), 2021 + +.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & + Peyré, G. (2019, April). `Interpolating between optimal transport and MMD + using Sinkhorn divergences + `_. In The 22nd + International Conference on Artificial Intelligence and Statistics (pp. + 2681-2690). PMLR. + +.. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., + & Schwing, A. G. (2019). `Max-sliced wasserstein distance and its use + for gans + `_. + In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + +.. [36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. + (2019, May). `Sliced-Wasserstein flows: Nonparametric generative modeling via + optimal transport and diffusions + `_. In International + Conference on Machine Learning (pp. 4104-4113). PMLR. + +.. [37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn barycenters + `_ Proceedings of + the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + +.. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, `Online + Graph Dictionary Learning `_\ , + International Conference on Machine Learning (ICML), 2021. + +.. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). + `Kantorovich duality for general transport costs and applications + `_. + Journal of Functional Analysis, 273(11), 3327-3405. + +.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & + Weed, J. (2019, April). `Statistical optimal transport via factored + couplings `_. In + The 22nd International Conference on Artificial Intelligence and Statistics + (pp. 2454-2465). PMLR. diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py deleted file mode 100644 index 12db605..0000000 --- a/ot/gpu/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -""" -GPU implementation for several OT solvers and utility -functions. - -The GPU backend in handled by `cupy -`_. - -.. warning:: - This module is now deprecated and will be removed in future releases. POT - now privides a backend mechanism that allows for solving prolem on GPU wth - the pytorch backend. - - -.. warning:: - Note that by default the module is not imported in :mod:`ot`. In order to - use it you need to explicitely import :mod:`ot.gpu` . - -By default, the functions in this module accept and return numpy arrays -in order to proide drop-in replacement for the other POT function but -the transfer between CPU en GPU comes with a significant overhead. - -In order to get the best performances, we recommend to give only cupy -arrays to the functions and desactivate the conversion to numpy of the -result of the function with parameter ``to_numpy=False``. - -""" - -# Author: Remi Flamary -# Leo Gautheron -# -# License: MIT License - -import warnings - -from . import bregman -from . import da -from .bregman import sinkhorn -from .da import sinkhorn_lpl1_mm - -from . import utils -from .utils import dist, to_gpu, to_np - - -warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning) - - -__all__ = ["utils", "dist", "sinkhorn", - "sinkhorn_lpl1_mm", 'bregman', 'da', 'to_gpu', 'to_np'] - diff --git a/ot/gpu/bregman.py b/ot/gpu/bregman.py deleted file mode 100644 index 76af00e..0000000 --- a/ot/gpu/bregman.py +++ /dev/null @@ -1,196 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Bregman projections for regularized OT with GPU -""" - -# Author: Remi Flamary -# Leo Gautheron -# -# License: MIT License - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations -from . import utils - - -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, to_numpy=True, **kwargs): - r""" - Solve the entropic regularization optimal transport on GPU - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - The function solves the following optimization problem: - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the (ns,nt) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - a and b are source and target weights (sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a = cp.asarray(a) - b = cp.asarray(b) - M = cp.asarray(M) - - if len(a) == 0: - a = np.ones((M.shape[0],)) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],)) / M.shape[1] - - # init data - Nini = len(a) - Nfin = len(b) - - if len(b.shape) > 1: - nbb = b.shape[1] - else: - nbb = 0 - - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if nbb: - u = np.ones((Nini, nbb)) / Nini - v = np.ones((Nfin, nbb)) / Nfin - else: - u = np.ones(Nini) / Nini - v = np.ones(Nfin) / Nfin - - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) - - Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 - err = 1 - while (err > stopThr and cpt < numItermax): - uprev = u - vprev = v - - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) - - if (np.any(KtransposeU == 0) or - np.any(np.isnan(u)) or np.any(np.isnan(v)) or - np.any(np.isinf(u)) or np.any(np.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) - u = uprev - v = vprev - break - if cpt % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - if nbb: - err = np.sqrt( - np.sum((u - uprev)**2) / np.sum((u)**2) - + np.sum((v - vprev)**2) / np.sum((v)**2) - ) - else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = np.sum(u[:, None] * K * v[None, :], 0) - #tmp2=np.einsum('i,ij,j->j', u, K, v) - err = np.linalg.norm(tmp2 - b) # violation of marginal - if log: - log['err'].append(err) - - if verbose: - if cpt % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 - if log: - log['u'] = u - log['v'] = v - - if nbb: # return only loss - #res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) (explodes cupy memory) - res = np.empty(nbb) - for i in range(nbb): - res[i] = np.sum(u[:, None, i] * (K * M) * v[None, :, i]) - if to_numpy: - res = utils.to_np(res) - if log: - return res, log - else: - return res - - else: # return OT matrix - res = u.reshape((-1, 1)) * K * v.reshape((1, -1)) - if to_numpy: - res = utils.to_np(res) - if log: - return res, log - else: - return res - - -# define sinkhorn as sinkhorn_knopp -sinkhorn = sinkhorn_knopp diff --git a/ot/gpu/da.py b/ot/gpu/da.py deleted file mode 100644 index 7adb830..0000000 --- a/ot/gpu/da.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Domain adaptation with optimal transport with GPU implementation -""" - -# Author: Remi Flamary -# Nicolas Courty -# Michael Perrot -# Leo Gautheron -# -# License: MIT License - - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations -import numpy as npp -from . import utils - -from .bregman import sinkhorn - - -def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, - numInnerItermax=200, stopInnerThr=1e-9, verbose=False, - log=False, to_numpy=True): - """ - Solve the entropic regularization optimal transport problem with nonconvex - group lasso regularization on GPU - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - - The function solves the following optimization problem: - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) - + \eta \Omega_g(\gamma) - - s.t. \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - where : - - - M is the (ns,nt) metric cost matrix - - :math:`\Omega_e` is the entropic regularization term - :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\Omega_g` is the group lasso regulaization term - :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` - where :math:`\mathcal{I}_c` are the index of samples from class c - in the source domain. - - a and b are source and target weights (sum to 1) - - The algorithm used for solving the problem is the generalised conditional - gradient as proposed in [5]_ [7]_ - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - labels_a : np.ndarray (ns,) - labels of samples in the source domain - b : np.ndarray (nt,) - samples weights in the target domain - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term for entropic regularization >0 - eta : float, optional - Regularization term for group lasso regularization >0 - numItermax : int, optional - Max number of iterations - numInnerItermax : int, optional - Max number of iterations (inner sinkhorn solver) - stopInnerThr : float, optional - Stop threshold on error (inner sinkhorn solver) (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , - vol.PP, no.99, pp.1-1 - .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). - Generalized conditional gradient: analysis of convergence - and applications. arXiv preprint arXiv:1510.06567. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT - - """ - - a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M) - - p = 0.5 - epsilon = 1e-3 - - indices_labels = [] - labels_a2 = cp.asnumpy(labels_a) - classes = npp.unique(labels_a2) - for c in classes: - idxc = utils.to_gpu(*npp.where(labels_a2 == c)) - indices_labels.append(idxc) - - W = np.zeros(M.shape) - - for cpt in range(numItermax): - Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr, to_numpy=False) - # the transport has been computed. Check if classes are really - # separated - W = np.ones(M.shape) - for (i, c) in enumerate(classes): - - majs = np.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon)**(p - 1)) - W[indices_labels[i]] = majs - - if to_numpy: - return utils.to_np(transp) - else: - return transp diff --git a/ot/gpu/utils.py b/ot/gpu/utils.py deleted file mode 100644 index 41e168a..0000000 --- a/ot/gpu/utils.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Utility functions for GPU -""" - -# Author: Remi Flamary -# Nicolas Courty -# Leo Gautheron -# -# License: MIT License - -import cupy as np # np used for matrix computation -import cupy as cp # cp used for cupy specific operations - - -def euclidean_distances(a, b, squared=False, to_numpy=True): - """ - Compute the pairwise euclidean distance between matrices a and b. - - If the input matrix are in numpy format, they will be uploaded to the - GPU first which can incur significant time overhead. - - Parameters - ---------- - a : np.ndarray (n, f) - first matrix - b : np.ndarray (m, f) - second matrix - to_numpy : boolean, optional (default True) - If true convert back the GPU array result to numpy format. - squared : boolean, optional (default False) - if True, return squared euclidean distance matrix - - Returns - ------- - c : (n x m) np.ndarray or cupy.ndarray - pairwise euclidean distance distance matrix - """ - - a, b = to_gpu(a, b) - - a2 = np.sum(np.square(a), 1) - b2 = np.sum(np.square(b), 1) - - c = -2 * np.dot(a, b.T) - c += a2[:, None] - c += b2[None, :] - - if not squared: - np.sqrt(c, out=c) - if to_numpy: - return to_np(c) - else: - return c - - -def dist(x1, x2=None, metric='sqeuclidean', to_numpy=True): - """Compute distance between samples in x1 and x2 on gpu - - Parameters - ---------- - - x1 : np.array (n1,d) - matrix with n1 samples of size d - x2 : np.array (n2,d), optional - matrix with n2 samples of size d (if None then x2=x1) - metric : str - Metric from 'sqeuclidean', 'euclidean', - - - Returns - ------- - - M : np.array (n1,n2) - distance matrix computed with given metric - - """ - if x2 is None: - x2 = x1 - if metric == "sqeuclidean": - return euclidean_distances(x1, x2, squared=True, to_numpy=to_numpy) - elif metric == "euclidean": - return euclidean_distances(x1, x2, squared=False, to_numpy=to_numpy) - else: - raise NotImplementedError - - -def to_gpu(*args): - """ Upload numpy arrays to GPU and return them""" - if len(args) > 1: - return (cp.asarray(x) for x in args) - else: - return cp.asarray(args[0]) - - -def to_np(*args): - """ convert GPU arras to numpy and return them""" - if len(args) > 1: - return (cp.asnumpy(x) for x in args) - else: - return cp.asnumpy(args[0]) diff --git a/test/test_gpu.py b/test/test_gpu.py deleted file mode 100644 index 8e62a74..0000000 --- a/test/test_gpu.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Tests for module gpu for gpu acceleration """ - -# Author: Remi Flamary -# -# License: MIT License - -import numpy as np -import ot -import pytest - -try: # test if cudamat installed - import ot.gpu - nogpu = False -except ImportError: - nogpu = True - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_old_doctests(): - a = [.5, .5] - b = [.5, .5] - M = [[0., 1.], [1., 0.]] - G = ot.sinkhorn(a, b, M, 1) - np.testing.assert_allclose(G, np.array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]])) - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_dist(): - - rng = np.random.RandomState(0) - - for n_samples in [50, 100, 500, 1000]: - print(n_samples) - a = rng.rand(n_samples // 4, 100) - b = rng.rand(n_samples, 100) - - M = ot.dist(a.copy(), b.copy()) - M2 = ot.gpu.dist(a.copy(), b.copy()) - - np.testing.assert_allclose(M, M2, rtol=1e-10) - - M2 = ot.gpu.dist(a.copy(), b.copy(), metric='euclidean', to_numpy=False) - - # check raise not implemented wrong metric - with pytest.raises(NotImplementedError): - M2 = ot.gpu.dist(a.copy(), b.copy(), metric='cityblock', to_numpy=False) - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_sinkhorn(): - - rng = np.random.RandomState(0) - - for n_samples in [50, 100, 500, 1000]: - a = rng.rand(n_samples // 4, 100) - b = rng.rand(n_samples, 100) - - wa = ot.unif(n_samples // 4) - wb = ot.unif(n_samples) - - wb2 = np.random.rand(n_samples, 20) - wb2 /= wb2.sum(0, keepdims=True) - - M = ot.dist(a.copy(), b.copy()) - M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False) - - reg = 1 - - G = ot.sinkhorn(wa, wb, M, reg) - G1 = ot.gpu.sinkhorn(wa, wb, M, reg) - - np.testing.assert_allclose(G1, G, rtol=1e-10) - - # run all on gpu - ot.gpu.sinkhorn(wa, wb, M2, reg, to_numpy=False, log=True) - - # run sinkhorn for multiple targets - ot.gpu.sinkhorn(wa, wb2, M2, reg, to_numpy=False, log=True) - - -@pytest.mark.skipif(nogpu, reason="No GPU available") -def test_gpu_sinkhorn_lpl1(): - - rng = np.random.RandomState(0) - - for n_samples in [50, 100, 500]: - print(n_samples) - a = rng.rand(n_samples // 4, 100) - labels_a = np.random.randint(10, size=(n_samples // 4)) - b = rng.rand(n_samples, 100) - - wa = ot.unif(n_samples // 4) - wb = ot.unif(n_samples) - - M = ot.dist(a.copy(), b.copy()) - M2 = ot.gpu.dist(a.copy(), b.copy(), to_numpy=False) - - reg = 1 - - G = ot.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg) - G1 = ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M, reg) - - np.testing.assert_allclose(G1, G, rtol=1e-10) - - ot.gpu.da.sinkhorn_lpl1_mm(wa, labels_a, wb, M2, reg, to_numpy=False, log=True) diff --git a/test/test_gromov.py b/test/test_gromov.py index 12fd2b9..9c85b92 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -188,7 +188,7 @@ def test_gromov2_gradients(): @pytest.skip_backend("jax", reason="test very slow with jax backend") @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov(nx): - n_samples = 50 # nb samples + n_samples = 10 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -222,9 +222,9 @@ def test_entropic_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov gw, log = ot.gromov.entropic_gromov_wasserstein2( - C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + C1, C2, p, q, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) gwb, logb = ot.gromov.entropic_gromov_wasserstein2( - C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True) + C1b, C2b, pb, qb, 'kl_loss', max_iter=10, epsilon=1e-2, log=True) gwb = nx.to_numpy(gwb) G = log['T'] @@ -245,7 +245,7 @@ def test_entropic_gromov(nx): @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_gromov_dtype_device(nx): # setup - n_samples = 50 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -280,7 +280,7 @@ def test_entropic_gromov_dtype_device(nx): def test_pointwise_gromov(nx): - n_samples = 50 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -331,14 +331,12 @@ def test_pointwise_gromov(nx): Gb = nx.to_numpy(nx.todense(Gb)) np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8) - np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8) @pytest.skip_backend("tf", reason="test very slow with tf backend") @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_sampled_gromov(nx): - n_samples = 50 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0], dtype=np.float64) cov_s = np.array([[1, 0], [0, 1]], dtype=np.float64) @@ -365,9 +363,9 @@ def test_sampled_gromov(nx): return nx.abs(x - y) G, log = ot.gromov.sampled_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + C1, C2, p, q, loss, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) Gb, logb = ot.gromov.sampled_gromov_wasserstein( - C1b, C2b, pb, qb, lossb, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + C1b, C2b, pb, qb, lossb, max_iter=20, nb_samples_grad=2, epsilon=1, log=True, verbose=True, random_state=42) Gb = nx.to_numpy(Gb) # check constraints @@ -377,13 +375,10 @@ def test_sampled_gromov(nx): np.testing.assert_allclose( q, Gb.sum(0), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08) - np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08) - def test_gromov_barycenter(nx): - ns = 10 - nt = 20 + ns = 5 + nt = 8 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -450,8 +445,8 @@ def test_gromov_barycenter(nx): @pytest.mark.filterwarnings("ignore:divide") def test_gromov_entropic_barycenter(nx): - ns = 10 - nt = 20 + ns = 5 + nt = 10 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -517,7 +512,7 @@ def test_gromov_entropic_barycenter(nx): def test_fgw(nx): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -579,7 +574,7 @@ def test_fgw(nx): def test_fgw2_gradients(): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -625,8 +620,8 @@ def test_fgw2_gradients(): def test_fgw_barycenter(nx): np.random.seed(42) - ns = 50 - nt = 60 + ns = 10 + nt = 20 Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) @@ -674,7 +669,7 @@ def test_fgw_barycenter(nx): def test_gromov_wasserstein_linear_unmixing(nx): - n = 10 + n = 4 X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) @@ -709,10 +704,10 @@ def test_gromov_wasserstein_linear_unmixing(nx): tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 ) - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=5e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=5e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=5e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=5e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06) np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06) @@ -758,7 +753,7 @@ def test_gromov_wasserstein_linear_unmixing(nx): def test_gromov_wasserstein_dictionary_learning(nx): # create dataset composed from 2 structures which are repeated 5 times - shape = 10 + shape = 4 n_samples = 2 n_atoms = 2 projection = 'nonnegative_symmetric' @@ -795,7 +790,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict_init, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) initial_total_reconstruction += reconstruction @@ -803,7 +798,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdict, log = ot.gromov.gromov_wasserstein_dictionary_learning( Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, epochs=epochs, batch_size=2 * n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary without backend @@ -811,7 +806,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction += reconstruction @@ -822,7 +817,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdictb, log = ot.gromov.gromov_wasserstein_dictionary_learning( Csb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # Compute reconstruction of samples on learned dictionary @@ -830,7 +825,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Csb[i], Cdictb, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b += reconstruction @@ -846,7 +841,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdict_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( Cs, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -854,7 +849,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict_bis, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis += reconstruction @@ -865,7 +860,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis, log = ot.gromov.gromov_wasserstein_dictionary_learning( Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=None, epochs=epochs, batch_size=n_samples, learning_rate=1., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -873,7 +868,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Csb[i], Cdictb_bis, p=None, q=None, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis += reconstruction @@ -892,7 +887,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdict_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( Cs, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -900,7 +895,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Cs[i], Cdict_bis2, p=ps[i], q=q, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis2 += reconstruction @@ -911,7 +906,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis2, log = ot.gromov.gromov_wasserstein_dictionary_learning( Csb, D=n_atoms, nt=shape, ps=psb, q=qb, Cdict_init=Cdictb, epochs=epochs, batch_size=n_samples, learning_rate=10., reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -919,7 +914,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, reconstruction = ot.gromov.gromov_wasserstein_linear_unmixing( Csb[i], Cdictb_bis2, p=psb[i], q=qb, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis2 += reconstruction @@ -929,7 +924,7 @@ def test_gromov_wasserstein_dictionary_learning(nx): def test_fused_gromov_wasserstein_linear_unmixing(nx): - n = 10 + n = 4 X1, y1 = ot.datasets.make_data_classif('3gauss', n, random_state=42) X2, y2 = ot.datasets.make_data_classif('3gauss2', n, random_state=42) F, y = ot.datasets.make_data_classif('3gauss', n, random_state=42) @@ -947,28 +942,28 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) - np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) - np.testing.assert_allclose(unmixing1, [1., 0.], atol=1e-01) - np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=1e-06) - np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01) + np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=4e-06) + np.testing.assert_allclose(unmixing1, [1., 0.], atol=4e-01) + np.testing.assert_allclose(unmixing2, nx.to_numpy(unmixing2b), atol=4e-06) + np.testing.assert_allclose(unmixing2, [0., 1.], atol=4e-01) np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-03) np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03) np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03) @@ -983,22 +978,22 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): unmixing1, C1_emb, Y1_emb, OT, reconstruction1 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1, F, Cdict, Ydict, p=p, q=p, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing1b, C1b_emb, Y1b_emb, OTb, reconstruction1b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C1b, Fb, Cdictb, Ydictb, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2, C2_emb, Y2_emb, OT, reconstruction2 = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2, F, Cdict, Ydict, p=None, q=None, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) unmixing2b, C2b_emb, Y2b_emb, OTb, reconstruction2b = ot.gromov.fused_gromov_wasserstein_linear_unmixing( C2b, Fb, Cdictb, Ydictb, p=pb, q=pb, alpha=0.5, reg=reg, - tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=20, max_iter_inner=200 + tol_outer=10**(-6), tol_inner=10**(-6), max_iter_outer=10, max_iter_inner=50 ) np.testing.assert_allclose(unmixing1, nx.to_numpy(unmixing1b), atol=1e-06) @@ -1018,7 +1013,7 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx): def test_fused_gromov_wasserstein_dictionary_learning(nx): # create dataset composed from 2 structures which are repeated 5 times - shape = 10 + shape = 4 n_samples = 2 n_atoms = 2 projection = 'nonnegative_symmetric' @@ -1060,7 +1055,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict_init, Ydict_init, p=ps[i], q=q, - alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + alpha=alpha, reg=0., tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) initial_total_reconstruction += reconstruction @@ -1069,7 +1064,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdict, Ydict, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict_init, Ydict_init=Ydict_init, epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1077,7 +1072,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict, Ydict, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction += reconstruction # Compare both @@ -1088,7 +1083,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdictb, Ydictb, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdict_initb, Ydict_init=Ydict_initb, epochs=epochs, batch_size=2 * n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1096,7 +1091,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Csb[i], Ysb[i], Cdictb, Ydictb, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b += reconstruction @@ -1111,7 +1106,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdict_bis, Ydict_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Cs, Ys, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1119,7 +1114,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict_bis, Ydict_bis, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis += reconstruction @@ -1130,7 +1125,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis, Ydictb_bis, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=None, Ydict_init=None, epochs=epochs, batch_size=n_samples, learning_rate_C=1., learning_rate_Y=1., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=False, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) @@ -1139,7 +1134,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Csb[i], Ysb[i], Cdictb_bis, Ydictb_bis, p=psb[i], q=qb, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis += reconstruction @@ -1156,7 +1151,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdict_bis2, Ydict_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Cs, Ys, D=n_atoms, nt=shape, ps=ps, q=q, Cdict_init=Cdict, Ydict_init=Ydict, epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) # > Compute reconstruction of samples on learned dictionary @@ -1164,7 +1159,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Cs[i], Ys[i], Cdict_bis2, Ydict_bis2, p=ps[i], q=q, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_bis2 += reconstruction @@ -1175,7 +1170,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): Cdictb_bis2, Ydictb_bis2, log = ot.gromov.fused_gromov_wasserstein_dictionary_learning( Csb, Ysb, D=n_atoms, nt=shape, ps=None, q=None, Cdict_init=Cdictb, Ydict_init=Ydictb, epochs=epochs, batch_size=n_samples, learning_rate_C=10., learning_rate_Y=10., alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200, + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50, projection=projection, use_log=use_log, use_adam_optimizer=use_adam_optimizer, verbose=verbose ) @@ -1184,7 +1179,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx): for i in range(n_samples): _, _, _, _, reconstruction = ot.gromov.fused_gromov_wasserstein_linear_unmixing( Csb[i], Ysb[i], Cdictb_bis2, Ydictb_bis2, p=None, q=None, alpha=alpha, reg=0., - tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200 + tol_outer=tol, tol_inner=tol, max_iter_outer=10, max_iter_inner=50 ) total_reconstruction_b_bis2 += reconstruction -- cgit v1.2.3 From ac4cf442735ed4c0d5405ad861eddaa02afd4edd Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Mon, 11 Apr 2022 15:38:18 +0200 Subject: [MRG] MM algorithms for UOT (#362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb * Delete partial.py * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * add test mm algo unbalanced OT * update unbalanced: mm algo+plot * update unbalanced: mm algo+plot * update releases.md with new MM UOT algorithms Co-authored-by: Rémi Flamary --- README.md | 6 +- RELEASES.md | 1 + docs/source/all.rst | 1 + examples/unbalanced-partial/plot_unbalanced_OT.py | 116 +++++ ot/partial.py | 84 +++- ot/regpath.py | 545 ++++++++++++++-------- ot/unbalanced.py | 223 +++++++++ test/test_unbalanced.py | 50 ++ 8 files changed, 802 insertions(+), 224 deletions(-) create mode 100644 examples/unbalanced-partial/plot_unbalanced_OT.py diff --git a/README.md b/README.md index 2ace69c..1b50aeb 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ POT provides the following generic OT solvers (links to examples): Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. -* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. +* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41] * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. @@ -309,4 +309,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405. -[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. \ No newline at end of file +[40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., & Weed, J. (2019, April). [Statistical optimal transport via factored couplings](http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2454-2465). PMLR. + +[41] Chapel*, L., Flamary*, R., Wu, H., Févotte, C., Gasso, G. (2021). [Unbalanced Optimal Transport through Non-negative Penalized Linear Regression](https://proceedings.neurips.cc/paper/2021/file/c3c617a9b80b3ae1ebd868b0017cc349-Paper.pdf) Advances in Neural Information Processing Systems (NeurIPS), 2020. (Two first co-authors) \ No newline at end of file diff --git a/RELEASES.md b/RELEASES.md index b54a84a..7942a15 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -19,6 +19,7 @@ - Add backend support for Domain Adaptation and Unbalanced solvers (PR #343). - Add (F)GW linear dictionary learning solvers + example (PR #319) - Add links to related PR and Issues in the doc release page (PR #350) +- Add new minimization-maximization algorithms for solving exact Unbalanced OT + example (PR #362) #### Closed issues diff --git a/docs/source/all.rst b/docs/source/all.rst index 3f7d029..1ec6be3 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -26,6 +26,7 @@ API and modules plot stochastic unbalanced + regpath partial sliced weak diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py new file mode 100644 index 0000000..03487e7 --- /dev/null +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +""" +============================================================== +2D examples of exact and entropic unbalanced optimal transport +============================================================== +This example is designed to show how to compute unbalanced and +partial OT in POT. + +UOT aims at solving the following optimization problem: + + .. math:: + W = \min_{\gamma} <\gamma, \mathbf{M}>_F + + \mathrm{reg}\cdot\Omega(\gamma) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + +where :math:`\mathrm{div}` is a divergence. +When using the entropic UOT, :math:`\mathrm{reg}>0` and :math:`\mathrm{div}` +should be the Kullback-Leibler divergence. +When solving exact UOT, :math:`\mathrm{reg}=0` and :math:`\mathrm{div}` +can be either the Kullback-Leibler or the quadratic divergence. +Using :math:`\ell_1` norm gives the so-called partial OT. +""" + +# Author: Laetitia Chapel +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 40 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +n_noise = 10 + +xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) +xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + +n = n + n_noise + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + + +############################################################################## +# Compute entropic kl-regularized UOT, kl- and l2-regularized UOT +# ----------- + +reg = 0.005 +reg_m_kl = 0.05 +reg_m_l2 = 5 +mass = 0.7 + +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') +l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') +partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) + +############################################################################## +# Plot the results +# ---------------- + +pl.figure(2) +transp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot] +title = ["partial OT \n m=" + str(mass), "$\ell_2$-UOT \n $\mathrm{reg_m}$=" + + str(reg_m_l2), "kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl), + "entropic kl-UOT \n $\mathrm{reg_m}$=" + str(reg_m_kl)] + +for p in range(4): + pl.subplot(2, 4, p + 1) + P = transp[p] + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2) + pl.title(title[p]) + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("mappings") + pl.subplot(2, 4, p + 5) + pl.imshow(P, cmap='jet') + pl.yticks(()) + pl.xticks(()) + if p < 1: + pl.ylabel("transport plans") +pl.show() diff --git a/ot/partial.py b/ot/partial.py index b7093e4..0a9e450 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -7,7 +7,6 @@ Partial OT solvers # License: MIT License import numpy as np - from .lp import emd @@ -29,7 +28,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, \gamma &\geq 0 - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m & + \leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} or equivalently (see Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. @@ -50,7 +50,8 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False, - :math:`\lambda` is the lagrangian cost. Tuning its value allows attaining a given mass to be transported `m` - The formulation of the problem has been proposed in :ref:`[28] ` + The formulation of the problem has been proposed in + :ref:`[28] ` Parameters @@ -261,7 +262,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 + M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 2 M_extended[:len(a), :len(b)] = M gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, @@ -455,7 +456,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[29] ` + The formulation of the problem has been proposed in + :ref:`[29] ` Parameters @@ -469,7 +471,8 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -623,16 +626,19 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, \gamma &\geq 0 - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[29] ` + The formulation of the problem has been proposed in + :ref:`[29] ` Parameters @@ -646,7 +652,8 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, q : ndarray, shape (nt,) Distribution in the target space m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported + (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) nb_dummies : int, optional Number of dummy points to add (avoid instabilities in the EMD solver) G0 : ndarray, shape (ns, nt), optional @@ -728,21 +735,25 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, The function considers the following problem: .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, + \mathbf{M} \rangle_F + \mathrm{reg} \cdot\Omega(\gamma) s.t. \gamma \mathbf{1} &\leq \mathbf{a} \\ \gamma^T \mathbf{1} &\leq \mathbf{b} \\ \gamma &\geq 0 \\ - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} \\ where : - :math:`\mathbf{M}` is the metric cost matrix - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights - `m` is the amount of mass to be transported - The formulation of the problem has been proposed in :ref:`[3] ` (prop. 5) + The formulation of the problem has been proposed in + :ref:`[3] ` (prop. 5) Parameters @@ -829,12 +840,23 @@ def entropic_partial_wasserstein(a, b, M, reg, m=None, numItermax=1000, np.multiply(K, m / np.sum(K), out=K) err, cpt = 1, 0 + q1 = np.ones(K.shape) + q2 = np.ones(K.shape) + q3 = np.ones(K.shape) while (err > stopThr and cpt < numItermax): Kprev = K + K = K * q1 K1 = np.dot(np.diag(np.minimum(a / np.sum(K, axis=1), dx)), K) + q1 = q1 * Kprev / K1 + K1prev = K1 + K1 = K1 * q2 K2 = np.dot(K1, np.diag(np.minimum(b / np.sum(K1, axis=0), dy))) + q2 = q2 * K1prev / K2 + K2prev = K2 + K2 = K2 * q3 K = K2 * (m / np.sum(K2)) + q3 = q3 * K2prev / K if np.any(np.isnan(K)) or np.any(np.isinf(K)): print('Warning: numerical errors at iteration', cpt) @@ -861,7 +883,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the partial Gromov-Wasserstein transport between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: @@ -877,7 +900,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, \gamma^T \mathbf{1} &\leq \mathbf{b} - \mathbf{1}^T \gamma^T \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} + \mathbf{1}^T \gamma^T \mathbf{1} = m + &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\} where : @@ -885,10 +909,13 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, - :math:`\mathbf{C_2}` is the metric cost matrix in the target space - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L`: quadratic loss function - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in :ref:`[12] ` and the partial GW in :ref:`[29] ` + The formulation of the GW problem has been proposed in + :ref:`[12] ` and the + partial GW in :ref:`[29] ` Parameters ---------- @@ -903,7 +930,8 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional @@ -1005,13 +1033,15 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, numItermax=1000, tol=1e-7, log=False, verbose=False): r""" - Returns the partial Gromov-Wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` + Returns the partial Gromov-Wasserstein discrepancy between + :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` The function solves the following optimization problem: .. math:: - GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l})\cdot - \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) + GW = \min_{\gamma} \quad \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, + \mathbf{C_2}_{j,l})\cdot + \gamma_{i,j}\cdot\gamma_{k,l} + \mathrm{reg} \cdot\Omega(\gamma) .. math:: s.t. \ \gamma &\geq 0 @@ -1028,10 +1058,13 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, - :math:`\mathbf{C_2}` is the metric cost matrix in the target space - :math:`\mathbf{p}` and :math:`\mathbf{q}` are the sample weights - `L` : quadratic loss function - - :math:`\Omega` is the entropic regularization term, :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term, + :math:`\Omega=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - `m` is the amount of mass to be transported - The formulation of the GW problem has been proposed in :ref:`[12] ` and the partial GW in :ref:`[29] ` + The formulation of the GW problem has been proposed in + :ref:`[12] ` and the + partial GW in :ref:`[29] ` Parameters @@ -1047,7 +1080,8 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, reg: float entropic regularization parameter m : float, optional - Amount of mass to be transported (default: :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) + Amount of mass to be transported (default: + :math:`\min\{\|\mathbf{p}\|_1, \|\mathbf{q}\|_1\}`) G0 : ndarray, shape (ns, nt), optional Initialisation of the transportation matrix numItermax : int, optional diff --git a/ot/regpath.py b/ot/regpath.py index 269937a..e745288 100644 --- a/ot/regpath.py +++ b/ot/regpath.py @@ -11,34 +11,48 @@ import scipy.sparse as sp def recast_ot_as_lasso(a, b, C): - r"""This function recasts the l2-penalized UOT problem as a Lasso problem + r"""This function recasts the l2-penalized UOT problem as a Lasso problem. + + Recall the l2-penalized UOT problem defined in + :ref:`[41] ` - Recall the l2-penalized UOT problem defined in [Chapel et al., 2021] .. math:: - UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + - \lambda \|T^T 1_n - b\|_2^2 + \text{UOT}_{\lambda} = \min_T + \lambda \|T 1_m - + \mathbf{a}\|_2^2 + + \lambda \|T^T 1_n - \mathbf{b}\|_2^2 + s.t. T \geq 0 + where : - - C is the (dim_a, dim_b) metric cost matrix - - :math:`\lambda` is the l2-regularization coefficient - - a and b are source and target distributions - - T is the transport plan to optimize - The problem above can be reformulated to a non-negative penalized + - :math:`C` is the cost matrix + - :math:`\lambda` is the l2-regularization parameter + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \ + distributions + - :math:`T` is the transport plan to optimize + + The problem above can be reformulated as a non-negative penalized linear regression problem, particularly Lasso + .. math:: - UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + \text{UOT2}_{\lambda} = \min_{\mathbf{t}} \gamma \mathbf{c}^T + \mathbf{t} + 0.5 * \|H \mathbf{t} - \mathbf{y}\|_2^2 + s.t. - t \geq 0 + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] - - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, - see [Chapel et al., 2021] for the design of H. The matrix product H t - computes both the source marginal and the target marginal. - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix :math:`C` + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}` + - :math:`H` is a metric matrix, see :ref:`[41] ` for \ + the design of :math:`H`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport plan \ + :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -47,14 +61,16 @@ def recast_ot_as_lasso(a, b, C): Histogram of dimension dim_b C : np.ndarray, shape (dim_a, dim_b) Cost matrix + Returns ------- H : np.ndarray (dim_a+dim_b, dim_a*dim_b) - Auxiliary matrix constituted by 0 and 1 + Design matrix that contains only 0 and 1 y : np.ndarray (ns + nt, ) - Concatenation of histogram a and histogram b + Concatenation of histograms :math:`\mathbf{a}` and :math:`\mathbf{b}` c : np.ndarray (ns * nt, ) - Flattened array of cost matrix + Flattened array of the cost matrix + Examples -------- >>> import ot @@ -73,12 +89,12 @@ def recast_ot_as_lasso(a, b, C): >>> c array([16., 25., 28., 16., 40., 36.]) + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ dim_a = np.shape(a)[0] @@ -97,33 +113,47 @@ def recast_ot_as_lasso(a, b, C): def recast_semi_relaxed_as_lasso(a, b, C): - r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem + r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem. .. math:: - semi-relaxed UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + + \text{semi-relaxed UOT} = \min_T + + \lambda \|T 1_m - \mathbf{a}\|_2^2 + s.t. - T^T 1_n = b - t \geq 0 + T^T 1_n = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - C is the (dim_a, dim_b) metric cost matrix - - :math:`\lambda` is the l2-regularization coefficient - - a and b are source and target distributions - - T is the transport plan to optimize + + - :math:`C` is the metric cost matrix + - :math:`\lambda` is the l2-regularization parameter + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the source and target \ + distributions + - :math:`T` is the transport plan to optimize The problem above can be reformulated as follows + .. math:: - semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + \text{semi-relaxed UOT2} = \min_t \gamma \mathbf{c}^T t + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + s.t. - H_c t = b - t \geq 0 + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - H_r is a (dim_a, dim_a * dim_b) metric matrix, - which computes the sum along the rows of transport plan T - - H_c is a (dim_b, dim_a * dim_b) metric matrix, - which computes the sum along the columns of transport plan T - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is flattened version of the cost matrix :math:`C` + - :math:`\gamma = 1/\lambda` is the l2-regularization parameter + - :math:`H_r` is a metric matrix which computes the sum along the \ + rows of the transport plan :math:`T` + - :math:`H_c` is a metric matrix which computes the sum along the \ + columns of the transport plan :math:`T` + - :math:`\mathbf{t}` is the flattened version of :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -132,16 +162,18 @@ def recast_semi_relaxed_as_lasso(a, b, C): Histogram of dimension dim_b C : np.ndarray, shape (dim_a, dim_b) Cost matrix + Returns ------- Hr : np.ndarray (dim_a, dim_a * dim_b) Auxiliary matrix constituted by 0 and 1, which computes - the sum along the rows of transport plan T + the sum along the rows of transport plan :math:`T` Hc : np.ndarray (dim_b, dim_a * dim_b) Auxiliary matrix constituted by 0 and 1, which computes - the sum along the columns of transport plan T + the sum along the columns of transport plan :math:`T` c : np.ndarray (ns * nt, ) - Flattened array of cost matrix + Flattened array of the cost matrix + Examples -------- >>> import ot @@ -179,49 +211,60 @@ def recast_semi_relaxed_as_lasso(a, b, C): def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): r""" This function computes the next value of gamma if a variable - will be added in next iteration of the regularization path + is added in the next iteration of the regularization path. We look for the largest value of gamma such that the gradient of an inactive variable vanishes + .. math:: - \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i} + \max_{i \in \bar{A}} \frac{\mathbf{h}_i^T(H_A \phi - \mathbf{y})} + {\mathbf{h}_i^T H_A \delta - \mathbf{c}_i} + where : + - A is the current active set - - h_i is the ith column of auxiliary matrix H - - H_A is the sub-matrix constructed by the columns of H - whose indices belong to the active set A - - c_i is the ith element of cost vector c - - y is the concatenation of source and target distribution - - :math:`\phi` is the intercept of the solutions in current iteration - - :math:`\delta` is the slope of the solutions in current iteration + - :math:`\mathbf{h}_i` is the :math:`i` th column of the design \ + matrix :math:`{H}` + - :math:`{H}_A` is the sub-matrix constructed by the columns of \ + :math:`{H}` whose indices belong to the active set A + - :math:`\mathbf{c}_i` is the :math:`i` th element of the cost vector \ + :math:`\mathbf{c}` + - :math:`\mathbf{y}` is the concatenation of the source and target \ + distributions + - :math:`\phi` is the intercept of the solutions at the current iteration + - :math:`\delta` is the slope of the solutions at the current iteration + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : np.ndarray (size(A), ) + Intercept of the solutions at the current iteration + delta : np.ndarray (size(A), ) + Slope of the solutions at the current iteration HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H^T H + Matrix product of :math:`{H}^T {H}` Hty : np.ndarray (dim_a + dim_b, ) - Matrix product of H^T y + Matrix product of :math:`{H}^T \mathbf{y}` c: np.ndarray (dim_a * dim_b, ) - Flattened array of cost matrix C + Flattened array of the cost matrix :math:`{C}` active_index : list Indices of active variables current_gamma : float - Value of regularization coefficient at the start of current iteration + Value of the regularization parameter at the beginning of the current \ + iteration + Returns ------- next_gamma : float Value of gamma if a variable is added to active set in next iteration next_active_index : int Index of variable to be activated + + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ M = (HtH[:, active_index].dot(phi) - Hty) / \ (HtH[:, active_index].dot(delta) - c + 1e-16) @@ -237,56 +280,65 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, By taking the Lagrangian form of the problem, we obtain a similar update as the two-sided relaxed UOT + .. math:: - \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T - \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i} + + \max_{i \in \bar{A}} \frac{\mathbf{h}_{ri}^T(H_{rA} \phi - \mathbf{a}) + + \mathbf{h}_{c i}^T\phi_u}{\mathbf{h}_{r i}^T H_{r A} \delta + \ + \mathbf{h}_{c i} \delta_u - \mathbf{c}_i} + where : + - A is the current active set - - h_{r i} is the ith column of the matrix H_r - - h_{c i} is the ith column of the matrix H_c - - H_{r A} is the sub-matrix constructed by the columns of H_r - whose indices belong to the active set A - - c_i is the ith element of cost vector c - - y is the concatenation of source and target distribution + - :math:`\mathbf{h}_{r i}` is the ith column of the matrix :math:`H_r` + - :math:`\mathbf{h}_{c i}` is the ith column of the matrix :math:`H_c` + - :math:`H_{r A}` is the sub-matrix constructed by the columns of \ + :math:`H_r` whose indices belong to the active set A + - :math:`\mathbf{c}_i` is the :math:`i` th element of cost vector \ + :math:`\mathbf{c}` - :math:`\phi` is the intercept of the solutions in current iteration - :math:`\delta` is the slope of the solutions in current iteration - - :math:`\phi_u` is the intercept of Lagrange parameter in current - iteration - - :math:`\delta_u` is the slope of Lagrange parameter in current iteration + - :math:`\phi_u` is the intercept of Lagrange parameter at the \ + current iteration + - :math:`\delta_u` is the slope of Lagrange parameter at the \ + current iteration + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : np.ndarray (size(A), ) + Intercept of the solutions at the current iteration + delta : np.ndarray (size(A), ) + Slope of the solutions at the current iteration phi_u : np.ndarray (dim_b, ) - Intercept of the Lagrange parameter in current iteration (also linear) + Intercept of the Lagrange parameter at the current iteration delta_u : np.ndarray (dim_b, ) - Slope of the Lagrange parameter in current iteration (also linear) + Slope of the Lagrange parameter at the current iteration HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H_r^T H_r + Matrix product of :math:`H_r^T H_r` Hc : np.ndarray (dim_b, dim_a * dim_b) - Matrix that computes the sum along the columns of transport plan T + Matrix that computes the sum along the columns of the transport plan \ + :math:`T` Hra : np.ndarray (dim_a * dim_b, ) - Matrix product of H_r^T a + Matrix product of :math:`H_r^T \mathbf{a}` c: np.ndarray (dim_a * dim_b, ) - Flattened array of cost matrix C + Flattened array of cost matrix :math:`C` active_index : list Indices of active variables current_gamma : float Value of regularization coefficient at the start of current iteration + Returns ------- next_gamma : float Value of gamma if a variable is added to active set in next iteration next_active_index : int Index of variable to be activated + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ @@ -297,37 +349,48 @@ def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, def compute_next_removal(phi, delta, current_gamma): - r""" This function computes the next value of gamma if a variable - is removed in next iteration of regularization path + r""" This function computes the next gamma value if a variable + is removed at the next iteration of the regularization path. + + We look for the largest value of the regularization parameter such that + an element of the current solution vanishes - We look for the largest value of gamma such that - an element of current solution vanishes .. math:: \max_{j \in A} \frac{\phi_j}{\delta_j} + where : + - A is the current active set - - phi_j is the jth element of the intercept of current solution - - delta_j is the jth elemnt of the slope of current solution + - :math:`\phi_j` is the :math:`j` th element of the intercept of the \ + current solution + - :math:`\delta_j` is the :math:`j` th element of the slope of the \ + current solution + + Parameters ---------- - phi : np.ndarray (|A|, ) - Intercept of the solutions in current iteration (t is piecewise linear) - delta : np.ndarray (|A|, ) - Slope of the solutions in current iteration (t is piecewise linear) + phi : ndarray, shape (size(A), ) + Intercept of the solution at the current iteration + delta : ndarray, shape (size(A), ) + Slope of the solution at the current iteration current_gamma : float - Value of regularization coefficient at the start of current iteration + Value of the regularization parameter at the beginning of the \ + current iteration + Returns ------- next_removal_gamma : float - Value of gamma if a variable is removed in next iteration + Gamma value if a variable is removed at the next iteration next_removal_index : int - Index of the variable to remove in next iteration + Index of the variable to be removed at the next iteration + + + .. _references-regpath: References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ r_candidate = phi / (delta - 1e-16) r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0 @@ -335,56 +398,74 @@ def compute_next_removal(phi, delta, current_gamma): def complement_schur(M_current, b, d, id_pop): - r""" This function computes the inverse of matrix in regularization path - using Schur complement + r""" This function computes the inverse of the design matrix in the \ + regularization path using the Schur complement. Two cases may arise: + + Case 1: one variable is added to the active set + - Two cases may arise: Firstly one variable is added to the active set .. math:: M_{k+1}^{-1} = \begin{bmatrix} - M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\ - - s^{-1} b^T M_{k}^{-1} & s^{-1} + M_{k}^{-1} + s^{-1} M_{k}^{-1} \mathbf{b} \mathbf{b}^T M_{k}^{-1} \ + & - M_{k}^{-1} \mathbf{b} s^{-1} \\ + - s^{-1} \mathbf{b}^T M_{k}^{-1} & s^{-1} \end{bmatrix} + + where : - - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and - :math:`M_k` is the upper left block matrix in Schur formulation - - b is the upper right block matrix in Schur formulation. In our case, - b is reduced to a column vector and b^T is the lower left block matrix - - s is the Schur complement, given by - :math:`s = d - b^T M_{k}^{-1} b` in our case - - Secondly, one variable is removed from the active set + + - :math:`M_k^{-1}` is the inverse of the design matrix :math:`H_A^tH_A` \ + of the previous iteration + - :math:`\mathbf{b}` is the last column of :math:`M_{k}` + - :math:`s` is the Schur complement, given by \ + :math:`s = \mathbf{d} - \mathbf{b}^T M_{k}^{-1} \mathbf{b}` + + Case 2: one variable is removed from the active set. + .. math:: - M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} - + M_{k+1}^{-1} = M^{-1}_{k \backslash q} - \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}} + where : - - q is the index of column and row to delete - - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix - without qth column and qth row - - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element - - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}` + + - :math:`q` is the index of column and row to delete + - :math:`M^{-1}_{k \backslash q}` is the previous inverse matrix deprived \ + of the :math:`q` th column and :math:`q` th row + - :math:`r_{-q,q}` is the :math:`q` th column of :math:`M^{-1}_{k}` \ + without the :math:`q` th element + - :math:`r_{q, q}` is the element of :math:`q` th column and :math:`q` th \ + row in :math:`M^{-1}_{k}` + + Parameters ---------- - M_current : np.ndarray (|A|-1, |A|-1) - Inverse matrix in previous iteration - b : np.ndarray (|A|-1, ) - Upper right matrix in Schur complement, a column vector in our case + M_current : ndarray, shape (size(A)-1, size(A)-1) + Inverse matrix of :math:`H_A^tH_A` at the previous iteration, with \ + size(A) the size of the active set + b : ndarray, shape (size(A)-1, ) + None for case 2 (removal), last column of :math:`M_{k}` for case 1 \ + (addition) d : float - Lower right matrix in Schur complement, a scalar in our case - id_pop + should be equal to 2 when UOT and 1 for the semi-relaxed OT + id_pop : int Index of the variable to be removed, equal to -1 - if none of the variables is deleted in current iteration + if no variable is deleted at the current iteration + + Returns ------- - M : np.ndarray (|A|, |A|) - Inverse matrix needed in current iteration + M : ndarray, shape (size(A), size(A)) + Inverse matrix of :math:`H_A^tH_A` of the current iteration + + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ + if b is None: b = M_current[id_pop, :] b = np.delete(b, id_pop) @@ -409,33 +490,39 @@ def complement_schur(M_current, b, d, id_pop): def construct_augmented_H(active_index, m, Hc, HrHr): - r""" This function construct an augmented matrix for the first iteration of - semi-relaxed regularization path + r""" This function constructs an augmented matrix for the first iteration + of the semi-relaxed regularization path .. math:: - Augmented_H = + \text{Augmented}_H = \begin{bmatrix} 0 & H_{c A} \\ H_{c A}^T & H_{r A}^T H_{r A} \end{bmatrix} + where : - - H_{r A} is the sub-matrix constructed by the columns of H_r - whose indices belong to the active set A - - H_{c A} is the sub-matrix constructed by the columns of H_c - whose indices belong to the active set A + + - :math:`H_{r A}` is the sub-matrix constructed by the columns of \ + :math:`H_r` whose indices belong to the active set A + - :math:`H_{c A}` is the sub-matrix constructed by the columns of \ + :math:`H_c` whose indices belong to the active set A + + Parameters ---------- active_index : list - Indices of active variables + Indices of the active variables m : int Length of the target distribution Hc : np.ndarray (dim_b, dim_a * dim_b) - Matrix that computes the sum along the columns of transport plan T + Matrix that computes the sum along the columns of the transport plan \ + :math:`T` HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) - Matrix product of H_r^T H_r + Matrix product of :math:`H_r^T H_r` + Returns ------- - H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|) + H_augmented : np.ndarray (dim_b + size(A), dim_b + size(A)) Augmented matrix for the first iteration of the semi-relaxed regularization path """ @@ -451,18 +538,27 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, r"""This function gives the regularization path of l2-penalized UOT problem The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: - \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2 + s.t. - t \geq 0 + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`{C}` - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] - - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, - see [Chapel et al., 2021] for the design of H. The matrix product Ht - computes both the source marginal and the target marginal. - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}`, defined as \ + :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]` + - :math:`{H}` is a design matrix, see :ref:`[41] ` \ + for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport matrix + Parameters ---------- a : np.ndarray (dim_a,) @@ -478,11 +574,12 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the optimal transport matrix t_list : list - List of solutions in regularization path + List of solutions in the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of regularization coefficients in the regularization path + Examples -------- >>> import ot @@ -502,10 +599,9 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ n = np.shape(a)[0] @@ -580,22 +676,32 @@ def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, itmax=50000): r"""This function gives the regularization path of semi-relaxed - l2-UOT problem + l2-UOT problem. The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: - \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + + \min_t \gamma \mathbf{c}^T t + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + s.t. - H_c t = b - t \geq 0 + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + where : - - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) - - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient - - H_r is a (dim_a, dim_a * dim_b) metric matrix, - which computes the sum along the rows of transport plan T - - H_c is a (dim_b, dim_a * dim_b) metric matrix, - which computes the sum along the columns of transport plan T - - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`C` + - :math:`\gamma = 1/\lambda` is the l2-regularization parameter + - :math:`H_r` is a matrix that computes the sum along the rows of \ + the transport plan :math:`T` + - :math:`H_c` is a matrix that computes the sum along the columns of \ + the transport plan :math:`T` + - :math:`\mathbf{t}` is the flattened version of the transport plan \ + :math:`T` + Parameters ---------- a : np.ndarray (dim_a,) @@ -608,14 +714,16 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, l2-regularization coefficient itmax: int (optional) Maximum number of iteration + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the (unregularized) optimal transport matrix t_list : list - List of solutions in regularization path + List of all the optimal transport vectors of the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of the regularization parameters in the path + Examples -------- >>> import ot @@ -635,10 +743,9 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ n = np.shape(a)[0] @@ -722,8 +829,44 @@ def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, semi_relaxed=False, itmax=50000): - r"""This function combines both the semi-relaxed and the fully-relaxed - regularization paths of l2-UOT problem + r"""This function provides all the solutions of the regularization path \ + of the l2-UOT problem :ref:`[41] `. + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + + .. math:: + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|{H} \mathbf{t} - \mathbf{y}\|_2^2 + + s.t. + \mathbf{t} \geq 0 + + where : + + - :math:`\mathbf{c}` is the flattened version of the cost matrix \ + :math:`{C}` + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - :math:`\mathbf{y}` is the concatenation of vectors :math:`\mathbf{a}` \ + and :math:`\mathbf{b}`, defined as \ + :math:`\mathbf{y}^T = [\mathbf{a}^T \mathbf{b}^T]` + - :math:`{H}` is a design matrix, see :ref:`[41] ` \ + for the design of :math:`{H}`. The matrix product :math:`H\mathbf{t}` \ + computes both the source marginal and the target marginals. + - :math:`\mathbf{t}` is the flattened version of the transport matrix + + For the semi-relaxed problem, it optimizes the Lasso reformulation of the + l2-penalized UOT: + + .. math:: + + \min_t \gamma \mathbf{c}^T \mathbf{t} + + 0.5 * \|H_r \mathbf{t} - \mathbf{a}\|_2^2 + + s.t. + H_c \mathbf{t} = \mathbf{b} + + \mathbf{t} \geq 0 + Parameters ---------- @@ -736,23 +879,24 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, reg: float (optional) l2-regularization coefficient semi_relaxed : bool (optional) - Give the semi-relaxed path if true + Give the semi-relaxed path if True itmax: int (optional) Maximum number of iteration + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Flattened vector of optimal transport matrix + Flattened vector of the (unregularized) optimal transport matrix t_list : list - List of solutions in regularization path + List of all the optimal transport vectors of the regularization path gamma_list : list - List of regularization coefficient in regularization path + List of the regularization parameters in the path + References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ if semi_relaxed: t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, @@ -765,27 +909,33 @@ def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, def compute_transport_plan(gamma, gamma_list, Pi_list): r""" Given the regularization path, this function computes the transport - plan for any value of gamma by the piecewise linearity of the path + plan for any value of gamma thanks to the piecewise linearity of the path. .. math:: t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma) - where : - - :math:`\gamma` is the regularization coefficient + + where: + + - :math:`\gamma` is the regularization parameter - :math:`\phi(\gamma)` is the corresponding intercept - :math:`\delta(\gamma)` is the corresponding slope - - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix) + - :math:`\mathbf{t}` is the flattened version of the transport matrix + Parameters ---------- gamma : float Regularization coefficient gamma_list : list - List of regularization coefficients in regularization path + List of regularization parameters of the regularization path Pi_list : list - List of solutions in regularization path + List of all the solutions of the regularization path + Returns ------- t : np.ndarray (dim_a*dim_b, ) - Transport vector corresponding to the given value of gamma + Vectorization of the transport plan corresponding to the given value + of gamma + Examples -------- >>> import ot @@ -804,12 +954,13 @@ def compute_transport_plan(gamma, gamma_list, Pi_list): array([0. , 0. , 0. , 0.19722222, 0.05555556, 0. , 0. , 0.24722222, 0. ]) + + .. _references-regpath: References ---------- - [Chapel et al., 2021]: - Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized - linear regression. + linear regression. NeurIPS. """ if gamma >= gamma_list[0]: diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 503cc1e..90c920c 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -4,6 +4,7 @@ Regularized Unbalanced OT solvers """ # Author: Hicham Janati +# Laetitia Chapel # License: MIT License from __future__ import division @@ -1029,3 +1030,225 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None, log=log, **kwargs) else: raise ValueError("Unknown method '%s'." % method) + + +def mm_unbalanced(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, + stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] ` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float + Marginal relaxation term > 0 + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + Returns + ------- + gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'kl'), 2) + array([[0.3 , 0. ], + [0. , 0.07]]) + >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 1, 'l2'), 2) + array([[0.25, 0. ], + [0. , 0. ]]) + + + .. _references-regpath: + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd : Unregularized OT + ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT + """ + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = nx.ones(dim_a, type_as=M) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=M) / dim_b + + if G0 is None: + G = a[:, None] * b[None, :] + else: + G = G0 + + if log: + log = {'err': [], 'G': []} + + if div == 'kl': + K = nx.exp(M / - reg_m / 2) + elif div == 'l2': + K = nx.maximum(a[:, None] + b[None, :] - M / reg_m / 2, + nx.zeros((dim_a, dim_b), type_as=M)) + else: + warnings.warn("The div parameter should be either equal to 'kl' or \ + 'l2': it has been set to 'kl'.") + div = 'kl' + K = nx.exp(M / - reg_m / 2) + + for i in range(numItermax): + Gprev = G + + if div == 'kl': + u = nx.sqrt(a / (nx.sum(G, 1) + 1e-16)) + v = nx.sqrt(b / (nx.sum(G, 0) + 1e-16)) + G = G * K * u[:, None] * v[None, :] + elif div == 'l2': + Gd = nx.sum(G, 0, keepdims=True) + nx.sum(G, 1, keepdims=True) + 1e-16 + G = G * K / Gd + + err = nx.sqrt(nx.sum((G - Gprev) ** 2)) + if log: + log['err'].append(err) + log['G'].append(G) + if verbose: + print('{:5d}|{:8e}|'.format(i, err)) + if err < stopThr: + break + + if log: + log['cost'] = nx.sum(G * M) + return G, log + else: + return G + + +def mm_unbalanced2(a, b, M, reg_m, div='kl', G0=None, numItermax=1000, + stopThr=1e-15, verbose=False, log=False): + r""" + Solve the unbalanced optimal transport problem and return the OT plan. + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_m} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. + \gamma \geq 0 + + where: + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + unbalanced distributions + - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence + + The algorithm used for solving the problem is a maximization- + minimization algorithm as proposed in :ref:`[41] ` + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) + Unnormalized histogram of dimension `dim_b` + M : array-like (dim_a, dim_b) + loss matrix + reg_m: float + Marginal relaxation term > 0 + div: string, optional + Divergence to quantify the difference between the marginals. + Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic) + G0: array-like (dim_a, dim_b) + Initialization of the transport matrix + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + ot_distance : array-like + the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}` + log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + >>> import ot + >>> import numpy as np + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[1., 36.],[9., 4.]] + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'l2'),2) + 0.25 + >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 1, 'kl'),2) + 0.57 + + References + ---------- + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + See Also + -------- + ot.lp.emd2 : Unregularized OT loss + ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss + """ + _, log_mm = mm_unbalanced(a, b, M, reg_m, div=div, G0=G0, + numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=True) + + if log: + return log_mm['cost'], log_mm + else: + return log_mm['cost'] diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index db59504..02b3fc3 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -1,6 +1,7 @@ """Tests for module Unbalanced OT with entropy regularization""" # Author: Hicham Janati +# Laetitia Chapel # # License: MIT License @@ -286,3 +287,52 @@ def test_implemented_methods(nx): method=method) barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m, method=method) + + +def test_mm_convergence(nx): + n = 100 + rng = np.random.RandomState(42) + x = rng.randn(n, 2) + rng = np.random.RandomState(75) + y = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + + M = ot.dist(x, y) + M = M / M.max() + reg_m = 100 + a, b, M = nx.from_numpy(a, b, M) + + G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', + verbose=True, log=True) + loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2( + a, b, M, reg_m, div='kl', verbose=True)) + G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', + verbose=False, log=True) + + # check if the marginals come close to the true ones when large reg + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03) + np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03) + + # check if mm_unbalanced2 returns the correct loss + np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl, + atol=1e-5) + + # check in case no histogram is provided + a_np, b_np = np.array([]), np.array([]) + a, b = nx.from_numpy(a_np, b_np) + + G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl') + G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2') + np.testing.assert_allclose(G_kl_null, G_kl) + np.testing.assert_allclose(G_l2_null, G_l2) + + # test when G0 is given + G0 = ot.emd(a, b, M) + reg_m = 10000 + G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0) + G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0) + np.testing.assert_allclose(G0, G_kl, atol=1e-05) + np.testing.assert_allclose(G0, G_l2, atol=1e-05) -- cgit v1.2.3 From 486b0d6397182a57cd53651dca87fcea89747490 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Mon, 11 Apr 2022 16:26:30 +0200 Subject: [MRG] Center gradients for mass of emd2 and gw2 (#363) * center gradients for mass of emd2 and gw2 * debug fgw gradient * debug fgw --- RELEASES.md | 4 +++- ot/gromov.py | 7 +++++-- ot/lp/__init__.py | 7 ++++--- test/test_ot.py | 8 +++++++- 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 7942a15..33d1ab6 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,7 +5,7 @@ #### New features -- remode deprecated `ot.gpu` submodule (PR #361) +- Remove deprecated `ot.gpu` submodule (PR #361) - Update examples in the gallery (PR #359). - Add stochastic loss and OT plan computation for regularized OT and backend examples(PR #360). @@ -23,6 +23,8 @@ #### Closed issues +- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are + centered (Issue #364, PR #363) - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337, PR #338) - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349) diff --git a/ot/gromov.py b/ot/gromov.py index c5a82d1..55ab0bd 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -551,7 +551,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo= gC1 = nx.from_numpy(gC1, type_as=C10) gC2 = nx.from_numpy(gC2, type_as=C10) gw = nx.set_gradients(gw, (p0, q0, C10, C20), - (log_gw['u'], log_gw['v'], gC1, gC2)) + (log_gw['u'] - nx.mean(log_gw['u']), + log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2)) if log: return gw, log_gw @@ -793,7 +794,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 gC1 = nx.from_numpy(gC1, type_as=C10) gC2 = nx.from_numpy(gC2, type_as=C10) fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), - (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + (log_fgw['u'] - nx.mean(log_fgw['u']), + log_fgw['v'] - nx.mean(log_fgw['v']), + alpha * gC1, alpha * gC2, (1 - alpha) * T0)) if log: return fgw_dist, log_fgw diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index abf7fe0..390c32d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -517,7 +517,8 @@ def emd2(a, b, M, processes=1, log['warning'] = result_code_string log['result_code'] = result_code cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (log['u'], log['v'], G)) + (a0, b0, M0), (log['u'] - nx.mean(log['u']), + log['v'] - nx.mean(log['v']), G)) return [cost, log] else: def f(b): @@ -540,8 +541,8 @@ def emd2(a, b, M, processes=1, ) G = nx.from_numpy(G, type_as=type_as) cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), (nx.from_numpy(u, type_as=type_as), - nx.from_numpy(v, type_as=type_as), G)) + (a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), G)) check_result(result_code) return cost diff --git a/test/test_ot.py b/test/test_ot.py index bb258e2..bf832f6 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -147,7 +147,7 @@ def test_emd2_gradients(): b1 = torch.tensor(a, requires_grad=True) M1 = torch.tensor(M, requires_grad=True) - val = ot.emd2(a1, b1, M1) + val, log = ot.emd2(a1, b1, M1, log=True) val.backward() @@ -155,6 +155,12 @@ def test_emd2_gradients(): assert b1.shape == b1.grad.shape assert M1.shape == M1.grad.shape + assert np.allclose(a1.grad.cpu().detach().numpy(), + log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean()) + + assert np.allclose(b1.grad.cpu().detach().numpy(), + log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean()) + # Testing for bug #309, checking for scaling of gradient a2 = torch.tensor(a, requires_grad=True) b2 = torch.tensor(a, requires_grad=True) -- cgit v1.2.3 From eccb1386eea52b94b82456d126bd20cbe3198e05 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 21 Apr 2022 16:34:01 +0200 Subject: [MRG] Release 8.2 (#365) * release text and number * add examples in release fil build wheels * switch gallery to release * add much needed contributors file * debug circleci * une line of logos * working logo * back to stable sphinx galery --- .circleci/config.yml | 5 ++- CONTRIBUTORS.md | 52 ++++++++++++++++++++++++++++ README.md | 35 ++++--------------- RELEASES.md | 57 ++++++++++++++++++++++++++----- docs/source/_static/images/logo_3ia.jpg | Bin 0 -> 25029 bytes docs/source/_static/images/logo_anr.jpg | Bin 0 -> 23493 bytes docs/source/_static/images/logo_cnrs.jpg | Bin 0 -> 6918 bytes docs/source/contributors.rst | 6 ++++ docs/source/index.rst | 1 + docs/source/releases.rst | 2 +- examples/others/plot_logo.py | 10 +++--- ot/__init__.py | 2 +- 12 files changed, 122 insertions(+), 48 deletions(-) create mode 100644 CONTRIBUTORS.md create mode 100644 docs/source/_static/images/logo_3ia.jpg create mode 100644 docs/source/_static/images/logo_anr.jpg create mode 100644 docs/source/_static/images/logo_cnrs.jpg create mode 100644 docs/source/contributors.rst diff --git a/.circleci/config.yml b/.circleci/config.yml index 77ab45c..7e15a65 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -48,9 +48,8 @@ jobs: python -m pip install --user -e . python -m pip install --user --upgrade --no-cache-dir --progress-bar off -r requirements.txt python -m pip install --user --upgrade --progress-bar off -r docs/requirements.txt - python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler - - + python -m pip install --user --upgrade --progress-bar off ipython sphinx-gallery memory_profiler + # python -m pip install --user --upgrade --progress-bar off ipython "https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master" memory_profiler - save_cache: key: pip-cache paths: diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md new file mode 100644 index 0000000..ab64fba --- /dev/null +++ b/CONTRIBUTORS.md @@ -0,0 +1,52 @@ + + +## Creators and Maintainers + +This toolbox has been created and is maintained by: + +* [Rémi Flamary](http://remi.flamary.com/) +* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) + +## Contributors + +The contributors to this library are: + +* [Rémi Flamary](http://remi.flamary.com/) (EMD wrapper, Pytorch backend, DA + classes, conditional gradients, WDA, weak OT, linear OT mapping, documentation) +* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) (Original sinkhorn, + Wasserstein barycenters and convolutional barycenters, 1D wasserstein) +* [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation) +* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT, + Unbalanced OT non-regularized) +* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation) +* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation) +* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes) +* [Stanislas Chambon](https://slasnista.github.io/) (DA classes) +* [Antoine Rolet](https://arolet.github.io/) (EMD solver debug) +* Erwan Vautier (Gromov-Wasserstein) +* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers, + empirical sinkhorn) +* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) (Greenkhorn) +* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein, Fused-Gromov-Wasserstein) +* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters) +* [Romain Tavenard](https://rtavenar.github.io/) (1D Wasserstein) +* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) +* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) +* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) +* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) +* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) +* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) +* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) + +## Acknowledgments + +This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): + +* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab) +* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT) +* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD) +* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda) + +POT has benefited from the financing or manpower from the following partners: + +ANRCNRS3IA \ No newline at end of file diff --git a/README.md b/README.md index 1b50aeb..e2b33d9 100644 --- a/README.md +++ b/README.md @@ -180,35 +180,12 @@ This toolbox has been created and is maintained by * [Rémi Flamary](http://remi.flamary.com/) * [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/) -The contributors to this library are - -* [Alexandre Gramfort](http://alexandre.gramfort.net/) (CI, documentation) -* [Laetitia Chapel](http://people.irisa.fr/Laetitia.Chapel/) (Partial OT) -* [Michael Perrot](http://perso.univ-st-etienne.fr/pem82055/) (Mapping estimation) -* [Léo Gautheron](https://github.com/aje) (Initial GPU implementation) -* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1) (DA classes) -* [Stanislas Chambon](https://slasnista.github.io/) (DA classes) -* [Antoine Rolet](https://arolet.github.io/) (EMD solver debug) -* Erwan Vautier (Gromov-Wasserstein) -* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers) -* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) -* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein) -* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters) -* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) -* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) -* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) -* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) -* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) -* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) -* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends) -* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning) - -This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): - -* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab) -* [Mathieu Blondel](https://mblondel.org/) (original implementation smooth OT) -* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) (C++ code for EMD) -* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda) +The numerous contributors to this library are listed [here](CONTRIBUTORS.md). + +POT has benefited from the financing or manpower from the following partners: + +ANRCNRS3IA + ## Contributions and code of conduct diff --git a/RELEASES.md b/RELEASES.md index 33d1ab6..be2192e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,22 +1,61 @@ # Releases -## 0.8.2dev Development +## 0.8.2 + +This releases introduces several new notable features. The less important +but most exiting one being that we now have a logo for the toolbox (color +and dark background) : + +![](https://pythonot.github.io/master/_images/logo.svg)![](https://pythonot.github.io/master/_static/logo_dark.svg) + +This logo is generated using with matplotlib and using the solution of an OT +problem provided by POT (with `ot.emd`). Generating the logo can be done with a +simple python script also provided in the [documentation gallery](https://pythonot.github.io/auto_examples/others/plot_logo.html#sphx-glr-auto-examples-others-plot-logo-py). + +New OT solvers include [Weak +OT](https://pythonot.github.io/gen_modules/ot.weak.html#ot.weak.weak_optimal_transport) + and [OT with factored +coupling](https://pythonot.github.io/gen_modules/ot.factored.html#ot.factored.factored_optimal_transport) +that can be used on large datasets. The [Majorization Minimization](https://pythonot.github.io/gen_modules/ot.unbalanced.html?highlight=mm_#ot.unbalanced.mm_unbalanced) solvers for +non-regularized Unbalanced OT are now also available. We also now provide an +implementation of [GW and FGW unmixing](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein_linear_unmixing) and [dictionary learning](https://pythonot.github.io/gen_modules/ot.gromov.html#ot.gromov.gromov_wasserstein_dictionary_learning). It is now +possible to use autodiff to solve entropic an quadratic regularized OT in the +dual for full or stochastic optimization thanks to the new functions to compute +the dual loss for [entropic](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.loss_dual_entropic) and [quadratic](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.loss_dual_quadratic) regularized OT and reconstruct the [OT +plan](https://pythonot.github.io/gen_modules/ot.stochastic.html#ot.stochastic.plan_dual_entropic) on part or all of the data. They can be used for instance to solve OT +problems with stochastic gradient or for estimating the [dual potentials as +neural networks](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html#sphx-glr-auto-examples-backends-plot-stoch-continuous-ot-pytorch-py). + +On the backend front, we now have backend compatible functions and classes in +the domain adaptation [`ot.da`](https://pythonot.github.io/gen_modules/ot.da.html#module-ot.da) and unbalanced OT [`ot.unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html) modules. This +means that the DA classes can be used on tensors from all compatible backends. +The [free support Wasserstein barycenter](https://pythonot.github.io/gen_modules/ot.lp.html?highlight=free%20support#ot.lp.free_support_barycenter) solver is now also backend compatible. + +Finally we have worked on the documentation to provide an update of existing +examples in the gallery and and several new examples including [GW dictionary +learning](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html#sphx-glr-auto-examples-gromov-plot-gromov-wasserstein-dictionary-learning-py) +[weak Optimal +Transport](https://pythonot.github.io/auto_examples/others/plot_WeakOT_VS_OT.html#sphx-glr-auto-examples-others-plot-weakot-vs-ot-py), +[NN based dual potentials +estimation](https://pythonot.github.io/auto_examples/backends/plot_stoch_continuous_ot_pytorch.html#sphx-glr-auto-examples-backends-plot-stoch-continuous-ot-pytorch-py) +and [Factored coupling OT](https://pythonot.github.io/auto_examples/others/plot_factored_coupling.html#sphx-glr-auto-examples-others-plot-factored-coupling-py). +. #### New features - Remove deprecated `ot.gpu` submodule (PR #361) -- Update examples in the gallery (PR #359). +- Update examples in the gallery (PR #359) - Add stochastic loss and OT plan computation for regularized OT and - backend examples(PR #360). -- Implementation of factored OT with emd and sinkhorn (PR #358). + backend examples(PR #360) +- Implementation of factored OT with emd and sinkhorn (PR #358) - A brand new logo for POT (PR #357) -- Better list of related examples in quick start guide with `minigallery` (PR #334). +- Better list of related examples in quick start guide with `minigallery` (PR #334) - Add optional log-domain Sinkhorn implementation in WDA to support smaller values - of the regularization parameter (PR #336). -- Backend implementation for `ot.lp.free_support_barycenter` (PR #340). -- Add weak OT solver + example (PR #341). -- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343). + of the regularization parameter (PR #336) +- Backend implementation for `ot.lp.free_support_barycenter` (PR #340) +- Add weak OT solver + example (PR #341) +- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343) - Add (F)GW linear dictionary learning solvers + example (PR #319) - Add links to related PR and Issues in the doc release page (PR #350) - Add new minimization-maximization algorithms for solving exact Unbalanced OT + example (PR #362) diff --git a/docs/source/_static/images/logo_3ia.jpg b/docs/source/_static/images/logo_3ia.jpg new file mode 100644 index 0000000..ecc56b2 Binary files /dev/null and b/docs/source/_static/images/logo_3ia.jpg differ diff --git a/docs/source/_static/images/logo_anr.jpg b/docs/source/_static/images/logo_anr.jpg new file mode 100644 index 0000000..dcef212 Binary files /dev/null and b/docs/source/_static/images/logo_anr.jpg differ diff --git a/docs/source/_static/images/logo_cnrs.jpg b/docs/source/_static/images/logo_cnrs.jpg new file mode 100644 index 0000000..902cf6f Binary files /dev/null and b/docs/source/_static/images/logo_cnrs.jpg differ diff --git a/docs/source/contributors.rst b/docs/source/contributors.rst new file mode 100644 index 0000000..f0acea6 --- /dev/null +++ b/docs/source/contributors.rst @@ -0,0 +1,6 @@ +Contributors +============ + +.. include:: ../../CONTRIBUTORS.md + :parser: myst_parser.sphinx_ + :start-line: 2 diff --git a/docs/source/index.rst b/docs/source/index.rst index 7ff7d22..3d53ef4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,6 +22,7 @@ Contents auto_examples/index releases .github/CONTRIBUTING + contributors .github/CODE_OF_CONDUCT diff --git a/docs/source/releases.rst b/docs/source/releases.rst index 8250a4d..b2c7a44 100644 --- a/docs/source/releases.rst +++ b/docs/source/releases.rst @@ -3,4 +3,4 @@ Releases .. include:: ../../RELEASES.md :parser: myst_parser.sphinx_ - :start-line: 3 + :start-line: 2 diff --git a/examples/others/plot_logo.py b/examples/others/plot_logo.py index 9414371..bb4f640 100644 --- a/examples/others/plot_logo.py +++ b/examples/others/plot_logo.py @@ -18,7 +18,7 @@ matplotlib and ploting teh solution of the EMD solver from POT. # sphinx_gallery_thumbnail_number = 1 -# %% +# %% Load modules import numpy as np import matplotlib.pyplot as pl import ot @@ -36,21 +36,21 @@ p2 = np.array([[1.5, 6], [2, 4], [2, 5], [1.5, 3], [0.5, 2], [.5, 1], ]) o1 = np.array([[0, 6.], [-1, 5], [-1.5, 4], [-1.5, 3], [-1, 2], [0, 1], ]) o2 = np.array([[1, 6.], [2, 5], [2.5, 4], [2.5, 3], [2, 2], [1, 1], ]) -# scaling and translation for letter O +# Scaling and translation for letter O o1[:, 0] += 6.4 o2[:, 0] += 6.4 o1[:, 0] *= 0.6 o2[:, 0] *= 0.6 -# letter T +# Letter T t1 = np.array([[-1, 6.], [-1, 5], [0, 4], [0, 3], [0, 2], [0, 1], ]) t2 = np.array([[1.5, 6.], [1.5, 5], [0.5, 4], [0.5, 3], [0.5, 2], [0.5, 1], ]) -# translatin the T +# Translating the T t1[:, 0] += 7.1 t2[:, 0] += 7.1 -# Cocatenate all letters +# Concatenate all letters x1 = np.concatenate((p1, o1, t1), axis=0) x2 = np.concatenate((p2, o2, t2), axis=0) diff --git a/ot/__init__.py b/ot/__init__.py index c5e1967..86ed94e 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ from .factored import factored_optimal_transport # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.2dev" +__version__ = "0.8.2" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', -- cgit v1.2.3