From 1511ecc67692773f016cc94190785faa33eab2d5 Mon Sep 17 00:00:00 2001 From: "D.J. Sutherland" Date: Wed, 13 May 2020 15:16:21 -0500 Subject: rename I to img in examples, since flake8 complains now (#176) --- examples/domain-adaptation/plot_otda_color_images.py | 8 ++++---- examples/domain-adaptation/plot_otda_linear_mapping.py | 8 ++++---- examples/domain-adaptation/plot_otda_mapping_colors_images.py | 8 ++++---- examples/gromov/plot_gromov_barycenter.py | 4 ++-- 4 files changed, 14 insertions(+), 14 deletions(-) (limited to 'examples') diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 929365e..d70f1fc 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -27,9 +27,9 @@ import ot r = np.random.RandomState(42) -def im2mat(I): +def im2mat(img): """Converts an image to matrix (one pixel per line)""" - return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) + return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) def mat2im(X, shape): @@ -37,8 +37,8 @@ def mat2im(X, shape): return X.reshape(shape) -def minmax(I): - return np.clip(I, 0, 1) +def minmax(img): + return np.clip(img, 0, 1) ############################################################################## diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index dbf16b8..be47510 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -79,9 +79,9 @@ pl.show() # --------------- -def im2mat(I): +def im2mat(img): """Converts and image to matrix (one pixel per line)""" - return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) + return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) def mat2im(X, shape): @@ -89,8 +89,8 @@ def mat2im(X, shape): return X.reshape(shape) -def minmax(I): - return np.clip(I, 0, 1) +def minmax(img): + return np.clip(img, 0, 1) # Loading images diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index ee5c8b0..aa41d22 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -29,9 +29,9 @@ import ot r = np.random.RandomState(42) -def im2mat(I): +def im2mat(img): """Converts and image to matrix (one pixel per line)""" - return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) + return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) def mat2im(X, shape): @@ -39,8 +39,8 @@ def mat2im(X, shape): return X.reshape(shape) -def minmax(I): - return np.clip(I, 0, 1) +def minmax(img): + return np.clip(img, 0, 1) ############################################################################## diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index f6f031a..e2d88ba 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -84,9 +84,9 @@ def smacof_mds(C, dim, max_iter=3000, eps=1e-9): # The four distributions are constructed from 4 simple images -def im2mat(I): +def im2mat(img): """Converts and image to matrix (one pixel per line)""" - return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) + return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2] -- cgit v1.2.3 From a8acfc31d3b4fd478ccb8f367549d5ebabca0d5c Mon Sep 17 00:00:00 2001 From: aboisbunon Date: Mon, 24 Aug 2020 15:58:29 +0200 Subject: [WIP] add introductory example of OT, EMD and Sinkhorn (#191) * add introductory example of OT, EMD and Sinkhorn * improve figure and try complying with pep8 * autopep8 * change markdown elements to rst * try solving issue with images * fix issue with images * add a section on varying the sinkhorn hyperparameter * add Sinkhorn algorithm and discussion for comparison between EMD and Sinkhorn * autopep8 again * add subsections and modify figure sizes/shapes * fix bug with print * correct some typos * remove computational time comparison * autopep8 again... Co-authored-by: Aurelie Boisbunon --- data/manhattan.npz | Bin 0 -> 2320938 bytes docs/source/auto_examples/images/bak.png | Bin 0 -> 304669 bytes docs/source/auto_examples/images/sinkhorn.png | Bin 0 -> 37204 bytes examples/plot_Intro_OT.py | 373 ++++++++++++++++++++++++++ 4 files changed, 373 insertions(+) create mode 100644 data/manhattan.npz create mode 100644 docs/source/auto_examples/images/bak.png create mode 100644 docs/source/auto_examples/images/sinkhorn.png create mode 100644 examples/plot_Intro_OT.py (limited to 'examples') diff --git a/data/manhattan.npz b/data/manhattan.npz new file mode 100644 index 0000000..37808fb Binary files /dev/null and b/data/manhattan.npz differ diff --git a/docs/source/auto_examples/images/bak.png b/docs/source/auto_examples/images/bak.png new file mode 100644 index 0000000..25e7e8e Binary files /dev/null and b/docs/source/auto_examples/images/bak.png differ diff --git a/docs/source/auto_examples/images/sinkhorn.png b/docs/source/auto_examples/images/sinkhorn.png new file mode 100644 index 0000000..e003e13 Binary files /dev/null and b/docs/source/auto_examples/images/sinkhorn.png differ diff --git a/examples/plot_Intro_OT.py b/examples/plot_Intro_OT.py new file mode 100644 index 0000000..2e2c6fd --- /dev/null +++ b/examples/plot_Intro_OT.py @@ -0,0 +1,373 @@ +# coding: utf-8 +""" +============================================= +Introduction to Optimal Transport with Python +============================================= + +This example gives an introduction on how to use Optimal Transport in Python. + +""" + +# Author: Remi Flamary, Nicolas Courty, Aurelie Boisbunon +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 1 + +############################################################################## +# POT Python Optimal Transport Toolbox +# ------------------------------------ +# +# POT installation +# ``````````````````` +# +# * Install with pip:: +# +# pip install pot +# * Install with conda:: +# +# conda install -c conda-forge pot +# +# Import the toolbox +# ``````````````````` +# + +import numpy as np # always need it +import pylab as pl # do the plots + +import ot # ot + +import time + +############################################################################## +# Getting help +# ````````````` +# +# Online documentation : ``_ +# +# Or inline help: +# + +help(ot.dist) + + +############################################################################## +# First OT Problem +# ---------------- +# +# We will solve the Bakery/Cafés problem of transporting croissants from a +# number of Bakeries to Cafés in a City (in this case Manhattan). We did a +# quick google map search in Manhattan for bakeries and Cafés: +# +# .. image:: images/bak.png +# :align: center +# :alt: bakery-cafe-manhattan +# :width: 600px +# :height: 280px +# +# We extracted from this search their positions and generated fictional +# production and sale number (that both sum to the same value). +# +# We have acess to the position of Bakeries ``bakery_pos`` and their +# respective production ``bakery_prod`` which describe the source +# distribution. The Cafés where the croissants are sold are defined also by +# their position ``cafe_pos`` and ``cafe_prod``, and describe the target +# distribution. For fun we also provide a +# map ``Imap`` that will illustrate the position of these shops in the city. +# +# +# Now we load the data +# +# + +data = np.load('../data/manhattan.npz') + +bakery_pos = data['bakery_pos'] +bakery_prod = data['bakery_prod'] +cafe_pos = data['cafe_pos'] +cafe_prod = data['cafe_prod'] +Imap = data['Imap'] + +print('Bakery production: {}'.format(bakery_prod)) +print('Cafe sale: {}'.format(cafe_prod)) +print('Total croissants : {}'.format(cafe_prod.sum())) + + +############################################################################## +# Plotting bakeries in the city +# ----------------------------- +# +# Next we plot the position of the bakeries and cafés on the map. The size of +# the circle is proportional to their production. +# + +pl.figure(1, (7, 6)) +pl.clf() +pl.imshow(Imap, interpolation='bilinear') # plot the map +pl.scatter(bakery_pos[:, 0], bakery_pos[:, 1], s=bakery_prod, c='r', ec='k', label='Bakeries') +pl.scatter(cafe_pos[:, 0], cafe_pos[:, 1], s=cafe_prod, c='b', ec='k', label='Cafés') +pl.legend() +pl.title('Manhattan Bakeries and Cafés') + + +############################################################################## +# Cost matrix +# ----------- +# +# +# We can now compute the cost matrix between the bakeries and the cafés, which +# will be the transport cost matrix. This can be done using the +# `ot.dist `_ function that +# defaults to squared Euclidean distance but can return other things such as +# cityblock (or Manhattan distance). +# + +C = ot.dist(bakery_pos, cafe_pos) + +labels = [str(i) for i in range(len(bakery_prod))] +f = pl.figure(2, (14, 7)) +pl.clf() +pl.subplot(121) +pl.imshow(Imap, interpolation='bilinear') # plot the map +for i in range(len(cafe_pos)): + pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', + fontsize=14, fontweight='bold', ha='center', va='center') +for i in range(len(bakery_pos)): + pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', + fontsize=14, fontweight='bold', ha='center', va='center') +pl.title('Manhattan Bakeries and Cafés') + +ax = pl.subplot(122) +im = pl.imshow(C, cmap="coolwarm") +pl.title('Cost matrix') +cbar = pl.colorbar(im, ax=ax, shrink=0.5, use_gridspec=True) +cbar.ax.set_ylabel("cost", rotation=-90, va="bottom") + +pl.xlabel('Cafés') +pl.ylabel('Bakeries') +pl.tight_layout() + + +############################################################################## +# The red cells in the matrix image show the bakeries and cafés that are +# further away, and thus more costly to transport from one to the other, while +# the blue ones show those that are very close to each other, with respect to +# the squared Euclidean distance. + + +############################################################################## +# Solving the OT problem with `ot.emd `_ +# ----------------------------------------------------------------------------------- + +start = time.time() +ot_emd = ot.emd(bakery_prod, cafe_prod, C) +time_emd = time.time() - start + +############################################################################## +# The function returns the transport matrix, which we can then visualize (next section). + +############################################################################## +# Transportation plan vizualization +# ````````````````````````````````` +# +# A good vizualization of the OT matrix in the 2D plane is to denote the +# transportation of mass between a Bakery and a Café by a line. This can easily +# be done with a double ``for`` loop. +# +# In order to make it more interpretable one can also use the ``alpha`` +# parameter of plot and set it to ``alpha=G[i,j]/G.max()``. + +# Plot the matrix and the map +f = pl.figure(3, (14, 7)) +pl.clf() +pl.subplot(121) +pl.imshow(Imap, interpolation='bilinear') # plot the map +for i in range(len(bakery_pos)): + for j in range(len(cafe_pos)): + pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], [bakery_pos[i, 1], cafe_pos[j, 1]], + '-k', lw=3. * ot_emd[i, j] / ot_emd.max()) +for i in range(len(cafe_pos)): + pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', fontsize=14, + fontweight='bold', ha='center', va='center') +for i in range(len(bakery_pos)): + pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', fontsize=14, + fontweight='bold', ha='center', va='center') +pl.title('Manhattan Bakeries and Cafés') + +ax = pl.subplot(122) +im = pl.imshow(ot_emd) +for i in range(len(bakery_prod)): + for j in range(len(cafe_prod)): + text = ax.text(j, i, '{0:g}'.format(ot_emd[i, j]), + ha="center", va="center", color="w") +pl.title('Transport matrix') + +pl.xlabel('Cafés') +pl.ylabel('Bakeries') +pl.tight_layout() + +############################################################################## +# The transport matrix gives the number of croissants that can be transported +# from each bakery to each café. We can see that the bakeries only need to +# transport croissants to one or two cafés, the transport matrix being very +# sparse. + +############################################################################## +# OT loss and dual variables +# -------------------------- +# +# The resulting wasserstein loss loss is of the form: +# +# .. math:: +# W=\sum_{i,j}\gamma_{i,j}C_{i,j} +# +# where :math:`\gamma` is the optimal transport matrix. +# + +W = np.sum(ot_emd * C) +print('Wasserstein loss (EMD) = {0:.2f}'.format(W)) + +############################################################################## +# Regularized OT with Sinkhorn +# ---------------------------- +# +# The Sinkhorn algorithm is very simple to code. You can implement it directly +# using the following pseudo-code +# +# .. image:: images/sinkhorn.png +# :align: center +# :alt: Sinkhorn algorithm +# :width: 440px +# :height: 240px +# +# In this algorithm, :math:`\oslash` corresponds to the element-wise division. +# +# An alternative is to use the POT toolbox with +# `ot.sinkhorn `_ +# +# Be careful of numerical problems. A good pre-processing for Sinkhorn is to +# divide the cost matrix ``C`` by its maximum value. + +############################################################################## +# Algorithm +# ````````` + +# Compute Sinkhorn transport matrix from algorithm +reg = 0.1 +K = np.exp(-C / C.max() / reg) +nit = 100 +u = np.ones((len(bakery_prod), )) +for i in range(1, nit): + v = cafe_prod / np.dot(K.T, u) + u = bakery_prod / (np.dot(K, v)) +ot_sink_algo = np.atleast_2d(u).T * (K * v.T) # Equivalent to np.dot(np.diag(u), np.dot(K, np.diag(v))) + +# Compute Sinkhorn transport matrix with POT +ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg, M=C / C.max()) + +# Difference between the 2 +print('Difference between algo and ot.sinkhorn = {0:.2g}'.format(np.sum(np.power(ot_sink_algo - ot_sinkhorn, 2)))) + +############################################################################## +# Plot the matrix and the map +# ``````````````````````````` + +print('Min. of Sinkhorn\'s transport matrix = {0:.2g}'.format(np.min(ot_sinkhorn))) + +f = pl.figure(4, (13, 6)) +pl.clf() +pl.subplot(121) +pl.imshow(Imap, interpolation='bilinear') # plot the map +for i in range(len(bakery_pos)): + for j in range(len(cafe_pos)): + pl.plot([bakery_pos[i, 0], cafe_pos[j, 0]], + [bakery_pos[i, 1], cafe_pos[j, 1]], + '-k', lw=3. * ot_sinkhorn[i, j] / ot_sinkhorn.max()) +for i in range(len(cafe_pos)): + pl.text(cafe_pos[i, 0], cafe_pos[i, 1], labels[i], color='b', + fontsize=14, fontweight='bold', ha='center', va='center') +for i in range(len(bakery_pos)): + pl.text(bakery_pos[i, 0], bakery_pos[i, 1], labels[i], color='r', + fontsize=14, fontweight='bold', ha='center', va='center') +pl.title('Manhattan Bakeries and Cafés') + +ax = pl.subplot(122) +im = pl.imshow(ot_sinkhorn) +for i in range(len(bakery_prod)): + for j in range(len(cafe_prod)): + text = ax.text(j, i, np.round(ot_sinkhorn[i, j], 1), + ha="center", va="center", color="w") +pl.title('Transport matrix') + +pl.xlabel('Cafés') +pl.ylabel('Bakeries') +pl.tight_layout() + + +############################################################################## +# We notice right away that the matrix is not sparse at all with Sinkhorn, +# each bakery delivering croissants to all 5 cafés with that solution. Also, +# this solution gives a transport with fractions, which does not make sense +# in the case of croissants. This was not the case with EMD. + +############################################################################## +# Varying the regularization parameter in Sinkhorn +# ```````````````````````````````````````````````` +# + +reg_parameter = np.logspace(-3, 0, 20) +W_sinkhorn_reg = np.zeros((len(reg_parameter), )) +time_sinkhorn_reg = np.zeros((len(reg_parameter), )) + +f = pl.figure(5, (14, 5)) +pl.clf() +max_ot = 100 # plot matrices with the same colorbar +for k in range(len(reg_parameter)): + start = time.time() + ot_sinkhorn = ot.sinkhorn(bakery_prod, cafe_prod, reg=reg_parameter[k], M=C / C.max()) + time_sinkhorn_reg[k] = time.time() - start + + if k % 4 == 0 and k > 0: # we only plot a few + ax = pl.subplot(1, 5, k / 4) + im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot) + pl.title('reg={0:.2g}'.format(reg_parameter[k])) + pl.xlabel('Cafés') + pl.ylabel('Bakeries') + + # Compute the Wasserstein loss for Sinkhorn, and compare with EMD + W_sinkhorn_reg[k] = np.sum(ot_sinkhorn * C) +pl.tight_layout() + + +############################################################################## +# This series of graph shows that the solution of Sinkhorn starts with something +# very similar to EMD (although not sparse) for very small values of the +# regularization parameter, and tends to a more uniform solution as the +# regularization parameter increases. +# + +############################################################################## +# Wasserstein loss and computational time +# ``````````````````````````````````````` +# + +# Plot the matrix and the map +f = pl.figure(6, (4, 4)) +pl.clf() +pl.title("Comparison between Sinkhorn and EMD") + +pl.plot(reg_parameter, W_sinkhorn_reg, 'o', label="Sinkhorn") +XLim = pl.xlim() +pl.plot(XLim, [W, W], '--k', label="EMD") +pl.legend() +pl.xlabel("reg") +pl.ylabel("Wasserstein loss") + +############################################################################## +# In this last graph, we show the impact of the regularization parameter on +# the Wasserstein loss. We can see that higher +# values of ``reg`` leads to a much higher Wasserstein loss. +# +# The Wasserstein loss of EMD is displayed for +# comparison. The Wasserstein loss of Sinkhorn can be a little lower than that +# of EMD for low values of ``reg``, but it quickly gets much higher. +# -- cgit v1.2.3 From 78b44af2434f494c8f9e4c8c91003fbc0e1d4415 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Thu, 22 Oct 2020 09:28:53 +0100 Subject: [MRG] Sliced wasserstein (#203) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * example for log treatment in bregman.py * Improve doc * Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e * Add comments by Flamary * Delete repetitive description * Added raw string to avoid pbs with backslashes * Implements sliced wasserstein * Changed formatting of string for py3.5 support * Docstest, expected 0.0 and not 0. * Adressed comments by @rflamary * No 3d plot here * add sliced to the docs * Incorporate comments by @rflamary * add link to pdf Co-authored-by: Rémi Flamary --- README.md | 4 + docs/source/all.rst | 1 + examples/sliced-wasserstein/README.txt | 4 + examples/sliced-wasserstein/plot_variance.py | 84 ++++++++++++++++ ot/__init__.py | 3 +- ot/sliced.py | 144 +++++++++++++++++++++++++++ test/test_sliced.py | 85 ++++++++++++++++ 7 files changed, 324 insertions(+), 1 deletion(-) create mode 100644 examples/sliced-wasserstein/README.txt create mode 100644 examples/sliced-wasserstein/plot_variance.py create mode 100644 ot/sliced.py create mode 100644 test/test_sliced.py (limited to 'examples') diff --git a/README.md b/README.md index e3598f1..6fe528a 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. POT provides the following Machine Learning related solvers: @@ -180,6 +181,7 @@ The contributors to this library are * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) +* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + +[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 diff --git a/docs/source/all.rst b/docs/source/all.rst index d7b878f..f1f7075 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -27,6 +27,7 @@ API and modules stochastic unbalanced partial + sliced .. autosummary:: :toctree: ../modules/generated/ diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt new file mode 100644 index 0000000..a575345 --- /dev/null +++ b/examples/sliced-wasserstein/README.txt @@ -0,0 +1,4 @@ + + +Sliced Wasserstein Distance +--------------------------- \ No newline at end of file diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py new file mode 100644 index 0000000..f3deeff --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" +============================== +2D Sliced Wasserstein Distance +============================== + +This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +################################################################################### +# Compute Sliced Wasserstein distance for different seeds and number of projections +# ----------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +################################################################################### +# Plot Sliced Wasserstein Distance +# ----------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label="SWD") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 0e6e2e2..ec3ede2 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -39,6 +39,7 @@ from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 from .da import sinkhorn_lpl1_mm +from .sliced import sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -50,4 +51,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets' 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance'] diff --git a/ot/sliced.py b/ot/sliced.py new file mode 100644 index 0000000..4792576 --- /dev/null +++ b/ot/sliced.py @@ -0,0 +1,144 @@ +""" +Sliced Wasserstein Distance. + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + + +import numpy as np + + +def get_random_projections(n_projections, d, seed=None): + r""" + Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` + + Parameters + ---------- + n_projections : int + number of samples requested + d : int + dimension of the space + seed: int or RandomState, optional + Seed used for numpy random number generator + + Returns + ------- + out: ndarray, shape (n_projections, d) + The uniform unit vectors on the sphere + + Examples + -------- + >>> n_projections = 100 + >>> d = 5 + >>> projs = get_random_projections(n_projections, d) + >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + True + + """ + + if not isinstance(seed, np.random.RandomState): + random_state = np.random.RandomState(seed) + else: + random_state = seed + + projections = random_state.normal(0., 1., [n_projections, d]) + norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) + projections = projections / norm + return projections + + +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + + .. math:: + \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for numpy random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + """ + from .lp import emd2_1d + + X_s = np.asanyarray(X_s) + X_t = np.asanyarray(X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = np.full(n, 1 / n) + if b is None: + b = np.full(m, 1 / m) + + d = X_s.shape[1] + + projections = get_random_projections(n_projections, d, seed) + + X_s_projections = np.dot(projections, X_s.T) + X_t_projections = np.dot(projections, X_t.T) + + if log: + projected_emd = np.empty(n_projections) + else: + projected_emd = None + + res = 0. + + for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): + emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) + if projected_emd is not None: + projected_emd[i] = emd + res += emd + + res = (res / n_projections) ** 0.5 + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/test/test_sliced.py b/test/test_sliced.py new file mode 100644 index 0000000..a07d975 --- /dev/null +++ b/test/test_sliced.py @@ -0,0 +1,85 @@ +"""Tests for module sliced""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.sliced import get_random_projections + + +def test_get_random_projections(): + rng = np.random.RandomState(0) + projections = get_random_projections(1000, 50, rng) + np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + + +def test_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + + +def test_sliced_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert len(projections) == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 1) + a = rng.uniform(0, 1, n) + a /= a.sum() + y = rng.randn(m, 1) + u = ot.utils.unif(m) + res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) + expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) + np.testing.assert_almost_equal(res ** 2, expected) -- cgit v1.2.3 From 93785eba11b59d544f1edde6661e93ee587148ee Mon Sep 17 00:00:00 2001 From: Laetitia Chapel Date: Thu, 22 Oct 2020 10:58:31 +0200 Subject: [MRG] Fix bugs for partial OT (#215) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bugfix * update refs partial OT * fixes small typos in plot_partial_wass_and_gromov * fix small bugs in partial.py * update README * pep8 bugfix * modif doctest * fix bugtests * update on test_partial and test on the numerical precision on ot/partial * resolve merge pb Co-authored-by: Rémi Flamary --- README.md | 2 +- .../plot_partial_wass_and_gromov.py | 23 ++++--- ot/partial.py | 71 +++++++++++++--------- test/test_partial.py | 6 +- 4 files changed, 60 insertions(+), 42 deletions(-) (limited to 'examples') diff --git a/README.md b/README.md index 6fe528a..238faed 100644 --- a/README.md +++ b/README.md @@ -262,7 +262,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t [28] Caffarelli, L. A., McCann, R. J. (2010). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. -[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. +[29] Chapel, L., Alaya, M., Gasso, G. (2020). [Partial Optimal Transport with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), Advances in Neural Information Processing Systems (NeurIPS), 2020. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py index 0c5cbf9..ac4194c 100755 --- a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -4,7 +4,7 @@ Partial Wasserstein and Gromov-Wasserstein example ================================================== -This example is designed to show how to use the Partial (Gromov-)Wassertsein +This example is designed to show how to use the Partial (Gromov-)Wasserstein distance computation in POT. """ @@ -123,11 +123,12 @@ C1 = sp.spatial.distance.cdist(xs, xs) C2 = sp.spatial.distance.cdist(xt, xt) # transport 100% of the mass -print('-----m = 1') +print('------m = 1') m = 1 res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True) + m=m, log=True, + verbose=True) print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist'])) print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist'])) @@ -136,18 +137,20 @@ pl.figure(1, (10, 5)) pl.title("mass to be transported m = 1") pl.subplot(1, 2, 1) pl.imshow(res0, cmap='jet') -pl.title('Wasserstein') +pl.title('Gromov-Wasserstein') pl.subplot(1, 2, 2) pl.imshow(res, cmap='jet') -pl.title('Entropic Wasserstein') +pl.title('Entropic Gromov-Wasserstein') pl.show() # transport 2/3 of the mass -print('-----m = 2/3') +print('------m = 2/3') m = 2 / 3 -res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True) +res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True, + verbose=True) res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10, - m=m, log=True) + m=m, log=True, + verbose=True) print('Partial Wasserstein distance (m = 2/3): ' + str(log0['partial_gw_dist'])) @@ -158,8 +161,8 @@ pl.figure(1, (10, 5)) pl.title("mass to be transported m = 2/3") pl.subplot(1, 2, 1) pl.imshow(res0, cmap='jet') -pl.title('Partial Wasserstein') +pl.title('Partial Gromov-Wasserstein') pl.subplot(1, 2, 2) pl.imshow(res, cmap='jet') -pl.title('Entropic partial Wasserstein') +pl.title('Entropic partial Gromov-Wasserstein') pl.show() diff --git a/ot/partial.py b/ot/partial.py index eb707d8..814d779 100755 --- a/ot/partial.py +++ b/ot/partial.py @@ -230,9 +230,9 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. See Also -------- @@ -254,7 +254,7 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies) a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies) M_extended = np.zeros((len(a_extended), len(b_extended))) - M_extended[-1, -1] = np.max(M) * 1e5 + M_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 M_extended[:len(a), :len(b)] = M gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True, @@ -344,14 +344,13 @@ def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs): .. [28] Caffarelli, L. A., & McCann, R. J. (2010) Free boundaries in optimal transport and Monge-Ampere obstacle problems. Annals of mathematics, 673-730. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ partial_gw, log_w = partial_wasserstein(a, b, M, m, nb_dummies, log=True, **kwargs) - log_w['T'] = partial_gw if log: @@ -501,14 +500,14 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, >>> np.round(partial_gromov_wasserstein(C1, C2, a, b, m=0.25),2) array([[0. , 0. , 0. , 0. ], [0. , 0. , 0. , 0. ], - [0. , 0. , 0. , 0. ], - [0. , 0. , 0. , 0.25]]) + [0. , 0. , 0.25, 0. ], + [0. , 0. , 0. , 0. ]]) References ---------- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ @@ -530,20 +529,18 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, cpt = 0 err = 1 - eps = 1e-20 + if log: log = {'err': []} while (err > tol and cpt < numItermax): - Gprev = G0 + Gprev = np.copy(G0) M = gwgrad_partial(C1, C2, G0) - M[M < eps] = np.quantile(M, thres) - M_emd = np.zeros(dim_G_extended) M_emd[:len(p), :len(q)] = M - M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e5 + M_emd[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2 M_emd = np.asarray(M_emd, dtype=np.float64) Gc, logemd = emd(p_extended, q_extended, M_emd, log=True, **kwargs) @@ -565,6 +562,22 @@ def partial_gromov_wasserstein(C1, C2, p, q, m=None, nb_dummies=1, G0=None, print('{:5d}|{:8e}|{:8e}'.format(cpt, err, gwloss_partial(C1, C2, G0))) + deltaG = G0 - Gprev + a = gwloss_partial(C1, C2, deltaG) + b = 2 * np.sum(M * deltaG) + if b > 0: # due to numerical precision + gamma = 0 + cpt = numItermax + elif a > 0: + gamma = min(1, np.divide(-b, 2.0 * a)) + else: + if (a + b) < 0: + gamma = 1 + else: + gamma = 0 + cpt = numItermax + + G0 = Gprev + gamma * deltaG cpt += 1 if log: @@ -665,9 +678,9 @@ def partial_gromov_wasserstein2(C1, C2, p, q, m=None, nb_dummies=1, G0=None, References ---------- - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ @@ -887,12 +900,12 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, >>> y = np.array([3,2,98,199]).reshape((-1,1)) >>> C1 = sp.spatial.distance.cdist(x, x) >>> C2 = sp.spatial.distance.cdist(y, y) - >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b,50), 2) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50), 2) array([[0.12, 0.13, 0. , 0. ], [0.13, 0.12, 0. , 0. ], [0. , 0. , 0.25, 0. ], [0. , 0. , 0. , 0.25]]) - >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50, m=0.25), 2) + >>> np.round(entropic_partial_gromov_wasserstein(C1, C2, a, b, 50,0.25), 2) array([[0.02, 0.03, 0. , 0.03], [0.03, 0.03, 0. , 0.03], [0. , 0. , 0.03, 0. ], @@ -910,9 +923,9 @@ def entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, m=None, G0=None, .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. See Also -------- @@ -1044,9 +1057,9 @@ def entropic_partial_gromov_wasserstein2(C1, C2, p, q, reg, m=None, G0=None, .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). "Partial Gromov- - Wasserstein with Applications on Positive-Unlabeled Learning". - arXiv preprint arXiv:2002.08276. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal + Transport with Applications on Positive-Unlabeled Learning". + NeurIPS. """ partial_gw, log_gw = entropic_partial_gromov_wasserstein(C1, C2, p, q, reg, diff --git a/test/test_partial.py b/test/test_partial.py index 510e081..121f345 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -51,10 +51,12 @@ def test_raise_errors(): ot.partial.partial_gromov_wasserstein(M, M, p, q, m=-1, log=True) with pytest.raises(ValueError): - ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, log=True) + ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=2, + log=True) with pytest.raises(ValueError): - ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, log=True) + ot.partial.entropic_partial_gromov_wasserstein(M, M, p, q, reg=1, m=-1, + log=True) def test_partial_wasserstein_lagrange(): -- cgit v1.2.3 From f6139428e70ce964de3bef703ef13aa701a83620 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 22 Dec 2020 18:35:40 +0100 Subject: [WIP] Update documentation "Why OT" section (#220) * add some text + discussion sinkhorn * stating wrk on why POT * fix sphinx warnings + make html-noplot * discussion when not to use POT * add discussion which sinkhorn * edits on quickstart * more * remove warnings :any: * more * done * remove ref Co-authored-by: Alexandre Gramfort --- docs/Makefile | 5 + docs/source/quickstart.rst | 448 ++++++++++++++------- .../barycenters/plot_free_support_barycenter.py | 4 +- examples/domain-adaptation/plot_otda_jcpot.py | 4 +- examples/gromov/plot_barycenter_fgw.py | 2 +- examples/gromov/plot_fgw.py | 10 +- examples/plot_OT_1D_smooth.py | 2 +- examples/plot_OT_2D_samples.py | 2 +- examples/sliced-wasserstein/plot_variance.py | 16 +- examples/unbalanced-partial/plot_UOT_1D.py | 3 +- ot/__init__.py | 10 +- ot/bregman.py | 30 ++ 12 files changed, 367 insertions(+), 169 deletions(-) (limited to 'examples') diff --git a/docs/Makefile b/docs/Makefile index 3511a59..9892785 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -57,6 +57,11 @@ html: @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." +html-noplot: + $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + .PHONY: dirhtml dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8d8b03f..cf5d6aa 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -7,19 +7,170 @@ to use for different problems related to optimal transport (OT) and machine learning. We refer when we can to concrete examples in the documentation that are also available as notebooks on the POT Github. -This document is not a tutorial on numerical optimal transport. For this we strongly -recommend to read the very nice book [15]_ . +.. note:: + + For a good introduction to numerical optimal transport we refer the reader + to `the book `_ by Peyré and Cuturi + [15]_. For more detailed introduction to OT and how it can be used + in ML applications we refer the reader to the following `OTML tutorial + `_. + + + +Why Optimal Transport ? +----------------------- + + +When to use OT +^^^^^^^^^^^^^^ + +Optimal Transport (OT) is a mathematical problem introduced by Gaspard Monge in +1781 that aim at finding the most efficient way to move mass between +distributions. The cost of moving a unit of mass between two positions is called +the ground cost and the objective is to minimize the overall cost of moving one +mass distribution onto another one. The optimization problem can be expressed +for two distributions :math:`\mu_s` and :math:`\mu_t` as + +.. math:: + \min_{m, m \# \mu_s = \mu_t} \int c(x,m(x))d\mu_s(x) , + +where :math:`c(\cdot,\cdot)` is the ground cost and the constraint +:math:`m \# \mu_s = \mu_t` ensures that :math:`\mu_s` is completely transported to :math:`\mu_t`. +This problem is particularly difficult to solve because of this constraint and +has been replaced in practice (on discrete distributions) by a +linear program easier to solve. It corresponds to the Kantorovitch formulation +where the Monge mapping :math:`m` is replaced by a joint distribution +(OT matrix expressed in the next section) (see :ref:`kantorovitch_solve`). + +From the optimization problem above we can see that there are two main aspects +to the OT solution that can be used in practical applications: + +- The optimal value (Wasserstein distance): Measures similarity between distributions. +- The optimal mapping (Monge mapping, OT matrix): Finds correspondences between distributions. + + +In the first case, OT can be used to measure similarity between distributions +(or datasets), in this case the Wasserstein distance (the optimal value of the +problem) is used. In the second case one can be interested in the way the mass +is moved between the distributions (the mapping). This mapping can then be used +to transfer knowledge between distributions. + + +Wasserstein distance between distributions +"""""""""""""""""""""""""""""""""""""""""" + +OT is often used to measure similarity between distributions, especially +when they do not share the same support. When the support between the +distributions is disjoint OT-based Wasserstein distances compare favorably to +popular f-divergences including the popular Kullback-Leibler, Jensen-Shannon +divergences, and the Total Variation distance. What is particularly interesting +for data science applications is that one can compute meaningful sub-gradients +of the Wasserstein distance. For these reasons it became a very efficient tool +for machine learning applications that need to measure and optimize similarity +between empirical distributions. + + +Numerous contributions make use of this an approach is the machine learning (ML) +literature. For example OT was used for training `Generative +Adversarial Networks (GANs) `_ +in order to overcome the vanishing gradient problem. It has also +been used to find `discriminant `_ or +`robust `_ subspaces for a dataset. The +Wasserstein distance has also been used to measure `similarity between word +embeddings of documents `_ or +between `signals +`_ +or `spectra `_. + + + +OT for mapping estimation +""""""""""""""""""""""""" + +A very interesting aspect of OT problem is the OT mapping in itself. When +computing optimal transport between discrete distributions one output is the OT +matrix that will provide you with correspondences between the samples in each +distributions. + + +This correspondence is estimated with respect to the OT criterion and is found +in a non-supervised way, which makes it very interesting on problems of transfer +between datasets. It has been used to perform +`color transfer between images `_ or in +the context of `domain adaptation `_. +More recent applications include the use of extension of OT (Gromov-Wasserstein) +to find correspondences between languages in `word embeddings +`_. + + +When to use POT +^^^^^^^^^^^^^^^ + + +The main objective of POT is to provide OT solvers for the rapidly growing area +of OT in the context of machine learning. To this end we implement a number of +solvers that have been proposed in research papers. Doing so we aim to promote +reproducible research and foster novel developments. + + +One very important aspect of POT is its ability to be easily extended. For +instance we provide a very generic OT solver :any:`ot.optim.cg` that can solve +OT problems with any smooth/continuous regularization term making it +particularly practical for research purpose. Note that this generic solver has +been used to solve both graph Laplacian regularization OT and Gromov +Wasserstein [30]_. + + +.. note:: + + POT is originally designed to solve OT problems with Numpy interface and + is not yet compatible with Pytorch API. We are currently working on a torch + submodule that will provide OT solvers and losses for the most common deep + learning configurations. + + +When not to use POT +""""""""""""""""""" + +While POT has to the best of our knowledge one of the most efficient exact OT +solvers, it has not been designed to handle large scale OT problems. For +instance the memory cost for an OT problem is always :math:`\mathcal{O}(n^2)` in +memory because the cost matrix has to be computed. The exact solver in of time +complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been +proven to be nearly :math:`\mathcal{O}(n^2)` which is still too complex for very +large scale solvers. + + +If you need to solve OT with large number of samples, we recommend to use +entropic regularization and memory efficient implementation of Sinkhorn as +proposed in `GeomLoss `_. This +implementation is compatible with Pytorch and can handle large number of +samples. Another approach to estimate the Wasserstein distance for very large +number of sample is to use the trick from `Wasserstein GAN +`_ that solves the problem +in the dual with a neural network estimating the dual variable. Note that in this +case you are only solving an approximation of the Wasserstein distance because +the 1-Lipschitz constraint on the dual cannot be enforced exactly (approximated +through filter thresholding or regularization). Finally note that in order to +avoid solving large scale OT problems, a number of recent approached minimized +the expected Wasserstein distance on minibtaches that is different from the +Wasserstein but has better computational and +`statistical properties `_. + Optimal transport and Wasserstein distance ------------------------------------------ .. note:: + In POT, most functions that solve OT or regularized OT problems have two versions that return the OT matrix or the value of the optimal solution. For - instance :any:`ot.emd` return the OT matrix and :any:`ot.emd2` return the + instance :any:`ot.emd` returns the OT matrix and :any:`ot.emd2` returns the Wassertsein distance. This approach has been implemented in practice for all - solvers that return an OT matrix (even Gromov-Wasserstsein) + solvers that return an OT matrix (even Gromov-Wasserstsein). + +.. _kantorovitch_solve: Solving optimal transport ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -28,30 +179,31 @@ The optimal transport problem between discrete distributions is often expressed as .. math:: - \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \gamma^* = arg\min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j} s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0 -where : +where: -- :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`. -- :math:`a` and :math:`b` are histograms on the simplex (positive, sum to 1) that represent the -weights of each samples in the source an target distributions. + - :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`. + + - :math:`a` and :math:`b` are histograms on the simplex (positive, sum to 1) that represent the weights of each samples in the source an target distributions. Solving the linear program above can be done using the function :any:`ot.emd` that will return the optimal transport matrix :math:`\gamma^*`: .. code:: python - # a,b are 1D histograms (sum to 1 and positive) + # a and b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix - T=ot.emd(a,b,M) # exact linear program + T = ot.emd(a, b, M) # exact linear program -The method implemented for solving the OT problem is the network simplex, it is -implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the +The method implemented for solving the OT problem is the network simplex. It is +implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the solver is quite efficient and uses sparsity of the solution. .. hint:: + Examples of use for :any:`ot.emd` are available in : - :any:`auto_examples/plot_OT_2D_samples` @@ -62,10 +214,11 @@ solver is quite efficient and uses sparsity of the solution. Computing Wasserstein distance ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The value of the OT solution is often more of interest than the OT matrix : +The value of the OT solution is often more interesting than the OT matrix: .. math:: - OT(a,b)=\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + + OT(a,b) = \min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j} s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0 @@ -75,9 +228,9 @@ It can computed from an already estimated OT matrix with .. code:: python - # a,b are 1D histograms (sum to 1 and positive) + # a and b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix - W=ot.emd2(a,b,M) # Wasserstein distance / EMD value + W = ot.emd2(a, b, M) # Wasserstein distance / EMD value Note that the well known `Wasserstein distance `_ between distributions a and @@ -86,19 +239,19 @@ b is defined as .. math:: - W_p(a,b)=(\min_\gamma \sum_{i,j}\gamma_{i,j}\|x_i-y_j\|_p)^\frac{1}{p} + W_p(a,b)=(\min_{\gamma \in \mathbb{R}_+^{m\times n}} \sum_{i,j}\gamma_{i,j}\|x_i-y_j\|_p)^\frac{1}{p} s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0 This means that if you want to compute the :math:`W_2` you need to compute the square root of :any:`ot.emd2` when providing -:code:`M=ot.dist(xs,xt)` that use the squared euclidean distance by default. Computing -the :math:`W_1` wasserstein distance can be done directly with :any:`ot.emd2` -when providing :code:`M=ot.dist(xs,xt, metric='euclidean')` to use the euclidean +:code:`M = ot.dist(xs, xt)`, that uses the squared euclidean distance by default. Computing +the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2` +when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean distance. - .. hint:: + An example of use for :any:`ot.emd2` is available in : - :any:`auto_examples/plot_compute_emd` @@ -123,9 +276,9 @@ Another special case for estimating OT and Monge mapping is between Gaussian distributions. In this case there exists a close form solution given in Remark 2.29 in [15]_ and the Monge mapping is an affine function and can be also computed from the covariances and means of the source and target -distributions. In the case when the finite sample dataset is supposed gaussian, we provide -:any:`ot.da.OT_mapping_linear` that returns the parameters for the Monge -mapping. +distributions. In the case when the finite sample dataset is supposed Gaussian, +we provide :any:`ot.da.OT_mapping_linear` that returns the parameters for the +Monge mapping. Regularized Optimal Transport @@ -136,7 +289,7 @@ computational and statistical properties. We address in this section the regularized OT problems that can be expressed as .. math:: - \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma) + \gamma^* = arg\min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma) s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0 @@ -175,8 +328,8 @@ solution of the resulting optimization problem can be expressed as: where :math:`u` and :math:`v` are vectors and :math:`K=\exp(-M/\lambda)` where the :math:`\exp` is taken component-wise. In order to solve the optimization -problem, on can use an alternative projection algorithm called Sinkhorn-Knopp that can be very -efficient for large values of regularization. +problem, one can use an alternative projection algorithm called Sinkhorn-Knopp +that can be very efficient for large values of regularization. The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and :any:`ot.sinkhorn2` that return respectively the OT matrix and the value of the @@ -184,10 +337,10 @@ linear term. Note that the regularization parameter :math:`\lambda` in the equation above is given to those functions with the parameter :code:`reg`. >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.sinkhorn(a,b,M,1) + >>> a = [.5, .5] + >>> b = [.5, .5] + >>> M = [[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) array([[ 0.36552929, 0.13447071], [ 0.13447071, 0.36552929]]) @@ -195,7 +348,7 @@ More details about the algorithms used are given in the following note. .. note:: The main function to solve entropic regularized OT is :any:`ot.sinkhorn`. - This function is a wrapper and the parameter :code:`method` help you select + This function is a wrapper and the parameter :code:`method` allows you to select the actual algorithm used to solve the problem: + :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the @@ -206,9 +359,11 @@ More details about the algorithms used are given in the following note. :any:`ot.bregman.sinkhorn_epsilon_scaling` the epsilon scaling version of the algorithm [9]_. + :code:`method='greenkhorn'` calls :any:`ot.bregman.greenkhorn` the - greedy sinkhorn verison of the algorithm [22]_. + greedy Sinkhorn version of the algorithm [22]_. + + :code:`method='screenkhorn'` calls :any:`ot.bregman.screenkhorn` the + screening sinkhorn version of the algorithm [26]_. - In addition to all those variants of sinkhorn, we have another + In addition to all those variants of Sinkhorn, we have another implementation solving the problem in the smooth dual or semi-dual in :any:`ot.smooth`. This solver uses the :any:`scipy.optimize.minimize` function to solve the smooth problem with :code:`L-BFGS-B` algorithm. Tu use @@ -216,12 +371,28 @@ More details about the algorithms used are given in the following note. :any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='kl'` to choose entropic/Kullbach Leibler regularization. + **Choosing a Sinkhorn solver** -Recently [23]_ introduced the sinkhorn divergence that build from entropic + By default and when using a regularization parameter that is not too small + the default Sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the Sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the Sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + + + +Recently Genevay et al. [23]_ introduced the Sinkhorn divergence that build from entropic regularization to compute fast and differentiable geometric divergence between -empirical distributions. Note that we provide a function that compute directly -(with no need to pre compute the :code:`M` matrix) -the sinkhorn divergence for empirical distributions in +empirical distributions. Note that we provide a function that computes directly +(with no need to precompute the :code:`M` matrix) +the Sinkhorn divergence for empirical distributions in :any:`ot.bregman.empirical_sinkhorn_divergence`. Similarly one can compute the OT matrix and loss for empirical distributions with respectively :any:`ot.bregman.empirical_sinkhorn` and :any:`ot.bregman.empirical_sinkhorn2`. @@ -229,7 +400,7 @@ OT matrix and loss for empirical distributions with respectively Finally note that we also provide in :any:`ot.stochastic` several implementation of stochastic solvers for entropic regularized OT [18]_ [19]_. Those pure Python -implementations are not optimized for speed but provide a roust implementation +implementations are not optimized for speed but provide a robust implementation of algorithms in [18]_ [19]_. .. hint:: @@ -244,11 +415,11 @@ of algorithms in [18]_ [19]_. Other regularization ^^^^^^^^^^^^^^^^^^^^ -While entropic OT is the most common and favored in practice, there exist other -kind of regularization. We provide in POT two specific solvers for other -regularization terms, namely quadratic regularization and group lasso -regularization. But we also provide in :any:`ot.optim` two generic solvers that allows solving any -smooth regularization in practice. +While entropic OT is the most common and favored in practice, there exists other +kinds of regularizations. We provide in POT two specific solvers for other +regularization terms, namely quadratic regularization and group Lasso +regularization. But we also provide in :any:`ot.optim` two generic solvers +that allows solving any smooth regularization in practice. Quadratic regularization """""""""""""""""""""""" @@ -259,8 +430,8 @@ regularization of the form .. math:: \Omega(\gamma)=\sum_{i,j} \gamma_{i,j}^2 -this regularization term has a similar effect to entropic regularization in -densifying the OT matrix but it keeps some sort of sparsity that is lost with +This regularization term has an effect similar to entropic regularization by +densifying the OT matrix, yet it keeps some sort of sparsity that is lost with entropic regularization as soon as :math:`\lambda>0` [17]_. This problem can be solved with POT using solvers from :any:`ot.smooth`, more specifically functions :any:`ot.smooth.smooth_ot_dual` or @@ -278,30 +449,29 @@ choose the quadratic regularization. Group Lasso regularization """""""""""""""""""""""""" -Another regularization that has been used in recent years [5]_ is the group lasso +Another regularization that has been used in recent years [5]_ is the group Lasso regularization .. math:: \Omega(\gamma)=\sum_{j,G\in\mathcal{G}} \|\gamma_{G,j}\|_q^p -where :math:`\mathcal{G}` contains non overlapping groups of lines in the OT -matrix. This regularization proposed in [5]_ will promote sparsity at the group level and for +where :math:`\mathcal{G}` contains non-overlapping groups of lines in the OT +matrix. This regularization proposed in [5]_ promotes sparsity at the group level and for instance will force target samples to get mass from a small number of groups. Note that the exact OT solution is already sparse so this regularization does -not make sens if it is not combined with entropic regularization. Depending on +not make sense if it is not combined with entropic regularization. Depending on the choice of :code:`p` and :code:`q`, the problem can be solved with different -approaches. When :code:`q=1` and :code:`p<1` the problem is non convex but can +approaches. When :code:`q=1` and :code:`p<1` the problem is non-convex but can be solved using an efficient majoration minimization approach with :any:`ot.sinkhorn_lpl1_mm`. When :code:`q=2` and :code:`p=1` we recover the convex group lasso and we provide a solver using generalized conditional -gradient algorithm [7]_ in function -:any:`ot.da.sinkhorn_l1l2_gl`. +gradient algorithm [7]_ in function :any:`ot.da.sinkhorn_l1l2_gl`. .. hint:: - Examples of group Lasso regularization are available in : + Examples of group Lasso regularization are available in: - - :any:`auto_examples/plot_otda_classes` - - :any:`auto_examples/plot_otda_d2` + - :any:`auto_examples/domain-adaptation/plot_otda_classes` + - :any:`auto_examples/domain-adaptation/plot_otda_d2` Generic solvers @@ -322,11 +492,10 @@ you can use function :any:`ot.optim.cg` that will use a conditional gradient as proposed in [6]_ . You need to provide the regularization function as parameter ``f`` and its gradient as parameter ``df``. Note that the conditional gradient relies on iterative solving of a linearization of the problem using the exact -:any:`ot.emd` so it can be slow in practice. But, being an interior point -algorithm, it always returns a -transport matrix that does not violates the marginals. +:any:`ot.emd` so it can be quite slow in practice. However, being an interior point +algorithm, it always returns a transport matrix that does not violates the marginals. -Another generic solver is proposed to solve the problem +Another generic solver is proposed to solve the problem: .. math:: \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}+ \lambda_e\Omega_e(\gamma) + \lambda\Omega(\gamma) @@ -347,7 +516,7 @@ relies on :any:`ot.sinkhorn` for its iterations. Wasserstein Barycenters ----------------------- -A Wasserstein barycenter is a distribution that minimize its Wasserstein +A Wasserstein barycenter is a distribution that minimizes its Wasserstein distance with respect to other distributions [16]_. It corresponds to minimizing the following problem by searching a distribution :math:`\mu` such that @@ -378,18 +547,18 @@ be expressed as where :math:`b_k` are also weights in the simplex. In the non-regularized case, the problem above is a classical linear program. In this case we propose a -solver :any:`ot.lp.barycenter` that rely on generic LP solvers. By default the +solver :meth:`ot.lp.barycenter` that relies on generic LP solvers. By default the function uses :any:`scipy.optimize.linprog`, but more efficient LP solvers from cvxopt can be also used by changing parameter :code:`solver`. Note that this problem requires to solve a very large linear program and can be very slow in practice. Similarly to the OT problem, OT barycenters can be computed in the regularized -case. When using entropic regularization is used, the problem can be solved with a -generalization of the sinkhorn algorithm based on bregman projections [3]_. This +case. When entropic regularization is used, the problem can be solved with a +generalization of the Sinkhorn algorithm based on Bregman projections [3]_. This algorithm is provided in function :any:`ot.bregman.barycenter` also available as :any:`ot.barycenter`. In this case, the algorithm scales better to large -distributions and rely only on matrix multiplications that can be performed in +distributions and relies only on matrix multiplications that can be performed in parallel. In addition to the speedup brought by regularization, one can also greatly @@ -400,18 +569,18 @@ operators. We provide an implementation of this algorithm in function :any:`ot.bregman.convolutional_barycenter2d`. .. hint:: - Examples of Wasserstein (:any:`ot.lp.barycenter`) and regularized Wasserstein + Examples of Wasserstein (:meth:`ot.lp.barycenter`) and regularized Wasserstein barycenter (:any:`ot.bregman.barycenter`) computation are available in : - - :any:`auto_examples/plot_barycenter_1D` - - :any:`auto_examples/plot_barycenter_lp_vs_entropic` + - :any:`auto_examples/barycenters/plot_barycenter_1D` + - :any:`auto_examples/barycenters/plot_barycenter_lp_vs_entropic` An example of convolutional barycenter (:any:`ot.bregman.convolutional_barycenter2d`) computation for 2D images is available in : - - :any:`auto_examples/plot_convolutional_barycenter` + - :any:`auto_examples/barycenters/plot_convolutional_barycenter` @@ -419,7 +588,7 @@ Barycenters with free support ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Estimating the Wasserstein barycenter with free support but fixed weights -corresponds to solving the following optimization problem: +corresponds to solving the following optimization problem: .. math:: \min_{\{x_i\}} \quad \sum_{k} w_kW(\mu,\mu_k) @@ -436,7 +605,7 @@ return a locally optimal support :math:`\{x_i\}` for uniform or given weights An example of the free support barycenter estimation is available in : - - :any:`auto_examples/plot_free_support_barycenter` + - :any:`auto_examples/barycenters/plot_free_support_barycenter` @@ -444,7 +613,7 @@ return a locally optimal support :math:`\{x_i\}` for uniform or given weights Monge mapping and Domain adaptation ----------------------------------- -The original transport problem investigated by Gaspard Monge was seeking for a +The original transport problem investigated by Gaspard Monge was seeking for a mapping function that maps (or transports) between a source and target distribution but that minimizes the transport loss. The existence and uniqueness of this optimal mapping is still an open problem in the general case but has been proven @@ -462,24 +631,24 @@ approximate a Monge mapping from finite distributions. First note that when the source and target distributions are supposed to be Gaussian distributions, there exists a close form solution for the mapping and its an affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function -:any:`ot.da.OT_mapping_linear` that return the operator :math:`A` and vector +:any:`ot.da.OT_mapping_linear` that returns the operator :math:`A` and vector :math:`b`. Note that if the number of samples is too small there is a parameter -:code:`reg` that provide a regularization for the covariance matrix estimation. +:code:`reg` that provides a regularization for the covariance matrix estimation. For a more general mapping estimation we also provide the barycentric mapping -proposed in [6]_ . It is implemented in the class :any:`ot.da.EMDTransport` and -other transport based classes in :any:`ot.da` . Those classes are discussed more -in the following but follow an interface similar to sklearn classes. Finally a +proposed in [6]_. It is implemented in the class :any:`ot.da.EMDTransport` and +other transport-based classes in :any:`ot.da` . Those classes are discussed more +in the following but follow an interface similar to scikit-learn classes. Finally a method proposed in [8]_ that estimates a continuous mapping approximating the barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for -linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non linear mapping. +linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping. .. hint:: An example of the linear Monge mapping estimation is available in : - - :any:`auto_examples/plot_otda_linear_mapping` + - :any:`auto_examples/domain-adaptation/plot_otda_linear_mapping` Domain adaptation classes ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -491,21 +660,19 @@ transport labeled source samples onto the target distribution with no labels. We provide several classes based on :any:`ot.da.BaseTransport` that provide several OT and mapping estimations. The interface of those classes is similar to -classifiers in sklearn toolbox. At initialization, several parameters such as - regularization parameter value can be set. Then one needs to estimate the +classifiers in scikit-learn. At initialization, several parameters such as +regularization parameter value can be set. Then one needs to estimate the mapping with function :any:`ot.da.BaseTransport.fit`. Finally one can map the samples from source to target with :any:`ot.da.BaseTransport.transform` and from target to source with :any:`ot.da.BaseTransport.inverse_transform`. -Here is -an example for class :any:`ot.da.EMDTransport` : +Here is an example for class :any:`ot.da.EMDTransport`: .. code:: ot_emd = ot.da.EMDTransport() ot_emd.fit(Xs=Xs, Xt=Xt) - - Mapped_Xs= ot_emd.transform(Xs=Xs) + Xs_mapped = ot_emd.transform(Xs=Xs) A list of the provided implementation is given in the following note. @@ -514,24 +681,24 @@ A list of the provided implementation is given in the following note. Here is a list of the OT mapping classes inheriting from :any:`ot.da.BaseTransport` - * :any:`ot.da.EMDTransport` : Barycentric mapping with EMD transport - * :any:`ot.da.SinkhornTransport` : Barycentric mapping with Sinkhorn transport - * :any:`ot.da.SinkhornL1l2Transport` : Barycentric mapping with Sinkhorn + + * :any:`ot.da.EMDTransport`: Barycentric mapping with EMD transport + * :any:`ot.da.SinkhornTransport`: Barycentric mapping with Sinkhorn transport + * :any:`ot.da.SinkhornL1l2Transport`: Barycentric mapping with Sinkhorn + group Lasso regularization [5]_ - * :any:`ot.da.SinkhornLpl1Transport` : Barycentric mapping with Sinkhorn + + * :any:`ot.da.SinkhornLpl1Transport`: Barycentric mapping with Sinkhorn + non convex group Lasso regularization [5]_ - * :any:`ot.da.LinearTransport` : Linear mapping estimation between Gaussians + * :any:`ot.da.LinearTransport`: Linear mapping estimation between Gaussians [14]_ - * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_ + * :any:`ot.da.MappingTransport`: Nonlinear mapping estimation [8]_ .. hint:: - Example of the use of OTDA classes are available in : + Examples of the use of OTDA classes are available in: - - :any:`auto_examples/plot_otda_color_images` - - :any:`auto_examples/plot_otda_mapping` - - :any:`auto_examples/plot_otda_mapping_colors_images` - - :any:`auto_examples/plot_otda_semi_supervised` + - :any:`auto_examples/domain-adaptation/plot_otda_color_images` + - :any:`auto_examples/domain-adaptation/plot_otda_mapping` + - :any:`auto_examples/domain-adaptation/plot_otda_mapping_colors_images` + - :any:`auto_examples/domain-adaptation/plot_otda_semi_supervised` Other applications ------------------ @@ -545,14 +712,14 @@ Wasserstein Discriminant Analysis Wasserstein Discriminant Analysis [11]_ is a generalization of `Fisher Linear Discriminant Analysis `__ that allows discrimination between classes that are not linearly separable. It -consist in finding a linear projector optimizing the following criterion +consists in finding a linear projector optimizing the following criterion .. math:: P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i} OT_e(\mu_i\#P,\mu_j\#P)} where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT -loss and :math:`\mu_i` is the +loss and :math:`\mu_i` is the distribution of samples from class :math:`i`. :math:`P` is also constrained to be in the Stiefel manifold. WDA can be solved in POT using function :any:`ot.dr.wda`. It requires to have installed :code:`pymanopt` and @@ -561,6 +728,7 @@ respectively. Note that we also provide the Fisher discriminant estimator in :any:`ot.dr.fda` for easy comparison. .. warning:: + Note that due to the hard dependency on :code:`pymanopt` and :code:`autograd`, :any:`ot.dr` is not imported by default. If you want to use it you have to specifically import it with :code:`import ot.dr` . @@ -569,7 +737,7 @@ respectively. Note that we also provide the Fisher discriminant estimator in An example of the use of WDA is available in : - - :any:`auto_examples/plot_WDA` + - :any:`auto_examples/others/plot_WDA` Unbalanced optimal transport @@ -610,7 +778,7 @@ linear term. Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in : - - :any:`auto_examples/plot_UOT_1D` + - :any:`auto_examples/unbalanced-partial/plot_UOT_1D` Unbalanced Barycenters @@ -622,17 +790,17 @@ histograms with different masses as a Fréchet Mean: .. math:: \min_{\mu} \quad \sum_{k} w_kW_u(\mu,\mu_k) -Where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem +where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem can also be solved using generalized version of Sinkhorn's algorithm and it is implemented the main function :any:`ot.barycenter_unbalanced`. .. note:: The main function to compute UOT barycenters is :any:`ot.barycenter_unbalanced`. - This function is a wrapper and the parameter :code:`method` help you select + This function is a wrapper and the parameter :code:`method` helps you select the actual algorithm used to solve the problem: - + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.barycenter_unbalanced_sinkhorn_unbalanced` + + :code:`method='sinkhorn'` calls :meth:`ot.unbalanced.barycenter_unbalanced_sinkhorn_unbalanced` the generalized Sinkhorn algorithm [10]_. + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.barycenter_unbalanced_stabilized` the log stabilized version of the algorithm [10]_. @@ -642,7 +810,7 @@ implemented the main function :any:`ot.barycenter_unbalanced`. Examples of the use of :any:`ot.barycenter_unbalanced` are available in : - - :any:`auto_examples/plot_UOT_barycenter_1D` + - :any:`auto_examples/unbalanced-partial/plot_UOT_barycenter_1D` Partial optimal transport @@ -686,9 +854,9 @@ regularization of the problem. .. hint:: - Examples of the use of :any:`ot.partial` are available in : + Examples of the use of :any:`ot.partial` are available in: - - :any:`auto_examples/plot_partial` + - :any:`auto_examples/unbalanced-partial/plot_partial_wass_and_gromov` @@ -699,7 +867,7 @@ Gromov Wasserstein (GW) is a generalization of OT to distributions that do not l the same space [13]_. In this case one cannot compute distance between samples from the two distributions. [13]_ proposed instead to realign the metric spaces by computing a transport between distance matrices. The Gromow Wasserstein -alignement between two distributions can be expressed as the one minimizing: +alignment between two distributions can be expressed as the one minimizing: .. math:: GW = \min_\gamma \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*\gamma_{i,j}*\gamma_{k,l} @@ -731,8 +899,8 @@ positive matrix. We provide a block coordinate optimization procedure in barycenters respectively. Finally note that recently a fusion between Wasserstein and GW, coined Fused -Gromov-Wasserstein (FGW) has been proposed -in [24]_. It allows to compute a similarity between objects that are only partly in +Gromov-Wasserstein (FGW) has been proposed [24]_. +It allows to compute a similarity between objects that are only partly in the same space. As such it can be used to measure similarity between labeled graphs for instance and also provide computable barycenters. The implementations of FGW and FGW barycenter is provided in functions @@ -740,15 +908,15 @@ The implementations of FGW and FGW barycenter is provided in functions .. hint:: - Examples of computation of GW, regularized G and FGW are available in : + Examples of computation of GW, regularized G and FGW are available in: - - :any:`auto_examples/plot_gromov` - - :any:`auto_examples/plot_fgw` + - :any:`auto_examples/gromov/plot_gromov` + - :any:`auto_examples/gromov/plot_fgw` - Examples of GW, regularized GW and FGW barycenters are available in : + Examples of GW, regularized GW and FGW barycenters are available in: - - :any:`auto_examples/plot_gromov_barycenter` - - :any:`auto_examples/plot_barycenter_fgw` + - :any:`auto_examples/gromov/plot_gromov_barycenter` + - :any:`auto_examples/gromov/plot_barycenter_fgw` GPU acceleration @@ -764,20 +932,20 @@ implementations use the :code:`cupy` toolbox that obviously need to be installed algebra) have been implemented in :any:`ot.gpu`. Here is a short list on the main entries: - - :any:`ot.gpu.dist` : computation of distance matrix - - :any:`ot.gpu.sinkhorn` : computation of sinkhorn - - :any:`ot.gpu.sinkhorn_lpl1_mm` : computation of sinkhorn + group lasso + - :meth:`ot.gpu.dist`: computation of distance matrix + - :meth:`ot.gpu.sinkhorn`: computation of sinkhorn + - :meth:`ot.gpu.sinkhorn_lpl1_mm`: computation of sinkhorn + group lasso Note that while the :any:`ot.gpu` module has been designed to be compatible with -POT, calling its function with :any:`numpy` arrays will incur a large overhead due to +POT, calling its function with :any:`numpy` arrays will incur a large overhead due to the memory copy of the array on GPU prior to computation and conversion of the array after computation. To avoid this overhead, we provide functions -:any:`ot.gpu.to_gpu` and :any:`ot.gpu.to_np` that perform the conversion +:meth:`ot.gpu.to_gpu` and :meth:`ot.gpu.to_np` that perform the conversion explicitly. - .. warning:: - Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not + + Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not imported by default. If you want to use it you have to specifically import it with :code:`import ot.gpu` . @@ -785,8 +953,6 @@ explicitly. FAQ --- - - 1. **How to solve a discrete optimal transport problem ?** The solver for discrete OT is the function :py:mod:`ot.emd` that returns @@ -798,10 +964,10 @@ FAQ .. code:: python - # a,b are 1D histograms (sum to 1 and positive) + # a and b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix - T=ot.emd(a,b,M) # exact linear program - T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT + T = ot.emd(a, b, M) # exact linear program + T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT More detailed examples can be seen on this example: :doc:`auto_examples/plot_OT_2D_samples` @@ -823,15 +989,15 @@ FAQ 3. **Why is Sinkhorn slower than EMD ?** This might come from the choice of the regularization term. The speed of - convergence of sinkhorn depends directly on this term [22]_ and when the - regularization gets very small the problem try and approximate the exact OT + convergence of Sinkhorn depends directly on this term [22]_. When the + regularization gets very small the problem tries to approximate the exact OT which leads to slow convergence in addition to numerical problems. In other - words, for large regularization sinkhorn will be very fast to converge, for + words, for large regularization Sinkhorn will be very fast to converge, for small regularization (when you need an OT matrix close to the true OT), it might be quicker to use the EMD solver. - Also note that the numpy implementation of the sinkhorn can use parallel - computation depending on the configuration of your system but very important + Also note that the numpy implementation of Sinkhorn can use parallel + computation depending on the configuration of your system, yet very important speedup can be obtained by using a GPU implementation since all operations are matrix/vector products. @@ -863,11 +1029,6 @@ References problems `__. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. -.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, - `Supervised planetary unmixing with optimal - transport `__, - Whorkshop on Hyperspectral Image and Signal Processing : Evolution in - Remote Sensing (WHISPERS), 2016. .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport for Domain Adaptation `__, in IEEE @@ -955,7 +1116,7 @@ References iteration `__, Advances in Neural Information Processing Systems (NIPS) 31 -.. [23] Aude, G., Peyré, G., Cuturi, M., `Learning Generative Models with +.. [23] Genevay, A., Peyré, G., Cuturi, M., `Learning Generative Models with Sinkhorn Divergences `__, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 @@ -972,11 +1133,6 @@ References .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport , Advances in Neural Information Processing Systems 33 (NeurIPS). - -.. [27] Redko I., Courty N., Flamary R., Tuia D. (2019). Optimal Transport for Multi-source - Domain Adaptation under Target Shift , - Proceedings of the Twenty-Second International Conference on Artificial Intelligence - and Statistics (AISTATS) 22, 2019. .. [28] Caffarelli, L. A., McCann, R. J. (2020). Free boundaries in optimal transport and Monge-Ampere obstacle problems , @@ -985,3 +1141,7 @@ References .. [29] Chapel, L., Alaya, M., Gasso, G. (2019). Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning , arXiv preprint arXiv:2002.08276. + +.. [30] Flamary, Rémi, et al. "Optimal transport with Laplacian regularization: + Applications to domain adaptation and shape matching." NIPS Workshop on Optimal + Transport and Machine Learning OTML. 2014. diff --git a/examples/barycenters/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 27ddc8e..2d68a39 100644 --- a/examples/barycenters/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -==================================================== +======================================================== 2D free support Wasserstein barycenters of distributions -==================================================== +======================================================== Illustration of 2D Wasserstein barycenters if distributions are weighted sum of diracs. diff --git a/examples/domain-adaptation/plot_otda_jcpot.py b/examples/domain-adaptation/plot_otda_jcpot.py index c495690..0d974f4 100644 --- a/examples/domain-adaptation/plot_otda_jcpot.py +++ b/examples/domain-adaptation/plot_otda_jcpot.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- """ -======================== +================================ OT for multi-source target shift -======================== +================================ This example introduces a target shift problem with two 2D source and 1 target domain. diff --git a/examples/gromov/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 3f81765..556e08f 100644 --- a/examples/gromov/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ================================= -Plot graphs' barycenter using FGW +Plot graphs barycenter using FGW ================================= This example illustrates the computation barycenter of labeled graphs using diff --git a/examples/gromov/plot_fgw.py b/examples/gromov/plot_fgw.py index 97fe619..5475fb3 100644 --- a/examples/gromov/plot_fgw.py +++ b/examples/gromov/plot_fgw.py @@ -26,7 +26,7 @@ from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein ############################################################################## # Generate data -# --------- +# ------------- #%% parameters # We create two 1D random measures @@ -76,7 +76,7 @@ pl.show() ############################################################################## # Create structure matrices and across-feature distance matrix -# --------- +# ------------------------------------------------------------ #%% Structure matrices and across-features distance matrix C1 = ot.dist(xs) @@ -88,7 +88,7 @@ Got = ot.emd([], [], M) ############################################################################## # Plot matrices -# --------- +# ------------- #%% cmap = 'Reds' @@ -131,7 +131,7 @@ pl.show() ############################################################################## # Compute FGW/GW -# --------- +# -------------- #%% Computing FGW and GW alpha = 1e-3 @@ -145,7 +145,7 @@ Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, ############################################################################## # Visualize transport matrices -# --------- +# ---------------------------- #%% visu OT matrix cmap = 'Blues' diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 75cd295..b07f99f 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -87,7 +87,7 @@ pl.show() ############################################################################## # Solve Smooth OT -# -------------- +# --------------- #%% Smooth OT with KL regularization diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index 1544e82..af1bc12 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -107,7 +107,7 @@ pl.show() ############################################################################## # Emprirical Sinkhorn -# ---------------- +# ------------------- #%% sinkhorn diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index f3deeff..27df656 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -4,9 +4,11 @@ 2D Sliced Wasserstein Distance ============================== -This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. +This example illustrates the computation of the sliced Wasserstein Distance as +proposed in [31]. -[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of +measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 """ @@ -50,9 +52,9 @@ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') pl.legend(loc=0) pl.title('Source and target distributions') -################################################################################### -# Compute Sliced Wasserstein distance for different seeds and number of projections -# ----------- +############################################################################### +# Sliced Wasserstein distance for different seeds and number of projections +# ------------------------------------------------------------------------- n_seed = 50 n_projections_arr = np.logspace(0, 3, 25, dtype=int) @@ -66,9 +68,9 @@ for seed in range(n_seed): res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) -################################################################################### +############################################################################### # Plot Sliced Wasserstein Distance -# ----------- +# -------------------------------- pl.figure(2) pl.plot(n_projections_arr, res_mean, label="SWD") diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 2ea8b05..183849c 100644 --- a/examples/unbalanced-partial/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py @@ -61,8 +61,7 @@ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') ############################################################################## # Solve Unbalanced Sinkhorn -# -------------- - +# ------------------------- # Sinkhorn diff --git a/ot/__init__.py b/ot/__init__.py index ec3ede2..0116d33 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -37,7 +37,8 @@ from . import partial # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d from .bregman import sinkhorn, sinkhorn2, barycenter -from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 +from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, + sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance @@ -46,9 +47,10 @@ from .utils import dist, unif, tic, toc, toq __version__ = "0.7.0" -__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', - 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', +__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', + 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance'] + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'smooth', 'stochastic', 'unbalanced', 'partial'] diff --git a/ot/bregman.py b/ot/bregman.py index f1f8437..dcd35e1 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -67,6 +67,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + Returns ------- @@ -175,6 +190,21 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + Returns ------- W : (n_hists) ndarray or float -- cgit v1.2.3 From 184f8f4f7ac78f1dd7f653496d2753211a4e3426 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 1 Jun 2021 10:10:54 +0200 Subject: [MRG] POT numpy/torch/jax backends (#249) * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend Co-authored-by: Nicolas Courty Co-authored-by: Alexandre Gramfort --- .github/requirements_test_windows.txt | 10 + .github/workflows/build_tests.yml | 9 +- README.md | 8 +- docs/source/quickstart.rst | 68 +++- docs/source/readme.rst | 70 ++-- examples/README.txt | 2 +- examples/backends/README.txt | 4 + examples/backends/plot_unmix_optim_torch.py | 161 +++++++++ ot/__init__.py | 1 + ot/backend.py | 536 ++++++++++++++++++++++++++++ ot/bregman.py | 141 ++++---- ot/gpu/__init__.py | 4 +- ot/lp/__init__.py | 137 ++++--- ot/utils.py | 128 +++++-- requirements.txt | 3 + test/test_backend.py | 364 +++++++++++++++++++ test/test_bregman.py | 74 ++++ test/test_gromov.py | 10 +- test/test_ot.py | 91 ++++- test/test_partial.py | 4 +- test/test_utils.py | 76 +++- 21 files changed, 1692 insertions(+), 209 deletions(-) create mode 100644 .github/requirements_test_windows.txt create mode 100644 examples/backends/README.txt create mode 100644 examples/backends/plot_unmix_optim_torch.py create mode 100644 ot/backend.py create mode 100644 test/test_backend.py (limited to 'examples') diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt new file mode 100644 index 0000000..331dd57 --- /dev/null +++ b/.github/requirements_test_windows.txt @@ -0,0 +1,10 @@ +numpy +scipy>=1.3 +cython +matplotlib +autograd +pymanopt==0.2.4; python_version <'3' +pymanopt; python_version >= '3' +cvxopt +scikit-learn +pytest \ No newline at end of file diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 2fc6770..92a07b5 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -40,7 +40,7 @@ jobs: pip install -e . - name: Run tests run: | - python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot + python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes - name: Upload codecov run: | codecov @@ -142,11 +142,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install pytest "pytest-cov<2.6" + python -m pip install -r .github/requirements_test_windows.txt + python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -m pip install pytest "pytest-cov<2.6" - name: Install POT run: | - pip install -e . + python -m pip install -e . - name: Run tests run: | python -m pytest -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot diff --git a/README.md b/README.md index f5d18c1..e5e16e0 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html) with optional GPU implementation (requires cupy). +* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -77,8 +78,7 @@ The library has been tested on Linux, MacOSX and Windows. It requires a C++ comp - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) -- Matplotlib (>=1.5) +- Cython (>=0.23) (build only, not necessary when installing wheels from pip or conda) #### Pip installation @@ -129,7 +129,7 @@ Some sub-modules require additional dependences which are discussed below pip install pymanopt autograd ``` -* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. +* **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on [this page](https://docs-cupy.chainer.org/en/stable/install.html). Obviously you will need CUDA installed and a compatible GPU. Note that this module is deprecated since version 0.8 and will be deleted in the future. GPU is now handled automatically through the backends and several solver already can run on GPU using the Pytorch backend. ## Examples diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index cf5d6aa..fd046a1 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -15,6 +15,12 @@ are also available as notebooks on the POT Github. in ML applications we refer the reader to the following `OTML tutorial `_. +.. note:: + + Since version 0.8, POT provides a backend to automatically solve some OT + problems independently from the toolbox used by the user (numpy/torch/jax). + We provide a discussion about which functions are compatible in section + `Backend section <#solving-ot-with-multiple-backends>`_ . Why Optimal Transport ? @@ -158,7 +164,6 @@ Wasserstein but has better computational and `statistical properties `_. - Optimal transport and Wasserstein distance ------------------------------------------ @@ -922,6 +927,13 @@ The implementations of FGW and FGW barycenter is provided in functions GPU acceleration ^^^^^^^^^^^^^^^^ +.. warning:: + + The :any:`ot.gpu` has been deprecated since the release 0.8 of POT and + should not be used. The GPU implementation (in Pytorch for instance) can be + used with the novel backends using the compatible functions from POT. + + We provide several implementation of our OT solvers in :any:`ot.gpu`. Those implementations use the :code:`cupy` toolbox that obviously need to be installed. @@ -950,6 +962,60 @@ explicitly. use it you have to specifically import it with :code:`import ot.gpu` . +Solving OT with Multiple backends +--------------------------------- + +.. _backends_section: + +Since version 0.8, POT provides a backend that allows to code solvers +independently from the type of the input arrays. The idea is to provide the user +with a package that works seamlessly and returns a solution for instance as a +Pytorch tensors when the function has Pytorch tensors as input. + + +How it works +^^^^^^^^^^^^ + +The aim of the backend is to use the same function independently of the type of +the input arrays. + +For instance when executing the following code + +.. code:: python + + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + w = ot.emd2(a, b, M) # Wasserstein computation + +the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type +:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of +the function will be the same type as the inputs and on the same device. When +possible all computations are done on the same device and also when possible the +output will be differentiable with respect to the input of the function. + + + +List of compatible Backends +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +- `Numpy `_ (all functions and solvers) +- `Pytorch `_ (all outputs differentiable w.r.t. inputs) +- `Jax `_ (Some functions are differentiable some require a wrapper) + +List of compatible functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +This list will get longer for new releases and will hopefully disappear when POT +become fully implemented with the backend. + +- :any:`ot.emd` +- :any:`ot.emd2` +- :any:`ot.sinkhorn` +- :any:`ot.sinkhorn2` +- :any:`ot.dist` + + FAQ --- diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 3b594c2..82d3e6c 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -26,8 +26,7 @@ POT provides the following generic OT solvers (links to examples): Algorithm `__ [2] , stabilized version [9] [10], greedy Sinkhorn [22] and `Screening Sinkhorn - [26] `__ - with optional GPU implementation (requires cupy). + [26] `__. - Bregman projections for `Wasserstein barycenter `__ [3], `convolutional @@ -69,6 +68,11 @@ POT provides the following generic OT solvers (links to examples): - `Sliced Wasserstein `__ [31, 32]. +- `Several + backends `__ + for easy use of POT with + `Pytorch `__/`jax `__/`Numpy `__ + arrays. POT provides the following Machine Learning related solvers: @@ -104,12 +108,14 @@ paper `__: :: - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, + POT Python Optimal Transport library, + Journal of Machine Learning Research, 22(78):1−8, 2021. Website: https://pythonot.github.io/ In Bibtex format: -:: +.. code:: bibtex @article{flamary2021pot, author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, @@ -131,8 +137,8 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) -- Matplotlib (>=1.5) +- Cython (>=0.23) (build only, not necessary when installing wheels + from pip or conda) Pip installation ^^^^^^^^^^^^^^^^ @@ -140,19 +146,19 @@ Pip installation Note that due to a limitation of pip, ``cython`` and ``numpy`` need to be installed prior to installing POT. This can be done easily with -:: +.. code:: console pip install numpy cython You can install the toolbox through PyPI with: -:: +.. code:: console pip install POT or get the very latest version by running: -:: +.. code:: console pip install -U https://github.com/PythonOT/POT/archive/master.zip # with --user for user install (no root) @@ -163,7 +169,7 @@ If you use the Anaconda python distribution, POT is available in `conda-forge `__. To install it and the required dependencies: -:: +.. code:: console conda install -c conda-forge pot @@ -188,15 +194,17 @@ below - **ot.dr** (Wasserstein dimensionality reduction) depends on autograd and pymanopt that can be installed with: - :: +.. code:: shell - pip install pymanopt autograd + pip install pymanopt autograd - **ot.gpu** (GPU accelerated OT) depends on cupy that have to be installed following instructions on `this page `__. - -obviously you need CUDA installed and a compatible GPU. + Obviously you will need CUDA installed and a compatible GPU. Note + that this module is deprecated since version 0.8 and will be deleted + in the future. GPU is now handled automatically through the backends + and several solver already can run on GPU using the Pytorch backend. Examples -------- @@ -206,36 +214,36 @@ Short examples - Import the toolbox - .. code:: python +.. code:: python - import ot + import ot - Compute Wasserstein distances - .. code:: python +.. code:: python - # a,b are 1D histograms (sum to 1 and positive) - # M is the ground cost matrix - Wd=ot.emd2(a,b,M) # exact linear program - Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT - # if b is a matrix compute all distances to a and return a vector + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + Wd = ot.emd2(a, b, M) # exact linear program + Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT + # if b is a matrix compute all distances to a and return a vector - Compute OT matrix - .. code:: python +.. code:: python - # a,b are 1D histograms (sum to 1 and positive) - # M is the ground cost matrix - T=ot.emd(a,b,M) # exact linear program - T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT + # a and b are 1D histograms (sum to 1 and positive) + # M is the ground cost matrix + T = ot.emd(a, b, M) # exact linear program + T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT - Compute Wasserstein barycenter - .. code:: python +.. code:: python - # A is a n*d matrix containing d 1D histograms - # M is the ground cost matrix - ba=ot.barycenter(A,M,reg) # reg is regularization parameter + # A is a n*d matrix containing d 1D histograms + # M is the ground cost matrix + ba = ot.barycenter(A, M, reg) # reg is regularization parameter Examples and Notebooks ~~~~~~~~~~~~~~~~~~~~~~ diff --git a/examples/README.txt b/examples/README.txt index 69a9f84..b48487f 100644 --- a/examples/README.txt +++ b/examples/README.txt @@ -1,7 +1,7 @@ Examples gallery ================ -This is a gallery of all the POT example files. +This is a gallery of all the POT example files. OT and regularized OT diff --git a/examples/backends/README.txt b/examples/backends/README.txt new file mode 100644 index 0000000..3ee0e27 --- /dev/null +++ b/examples/backends/README.txt @@ -0,0 +1,4 @@ + + +POT backend examples +-------------------- \ No newline at end of file diff --git a/examples/backends/plot_unmix_optim_torch.py b/examples/backends/plot_unmix_optim_torch.py new file mode 100644 index 0000000..9ae66e9 --- /dev/null +++ b/examples/backends/plot_unmix_optim_torch.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +r""" +================================= +Wasserstein unmixing with PyTorch +================================= + +In this example we estimate mixing parameters from distributions that minimize +the Wasserstein distance. In other words we suppose that a target +distribution :math:`\mu^t` can be expressed as a weighted sum of source +distributions :math:`\mu^s_k` with the following model: + +.. math:: + \mu^t = \sum_{k=1}^K w_k\mu^s_k + +where :math:`\mathbf{w}` is a vector of size :math:`K` and belongs in the +distribution simplex :math:`\Delta_K`. + +In order to estimate this weight vector we propose to optimize the Wasserstein +distance between the model and the observed :math:`\mu^t` with respect to +the vector. This leads to the following optimization problem: + +.. math:: + \min_{\mathbf{w}\in\Delta_K} \quad W \left(\mu^t,\sum_{k=1}^K w_k\mu^s_k\right) + +This minimization is done in this example with a simple projected gradient +descent in PyTorch. We use the automatic backend of POT that allows us to +compute the Wasserstein distance with :any:`ot.emd2` with +differentiable losses. + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot +import torch + + +############################################################################## +# Generate data +# ------------- + +#%% Data + +nt = 100 +nt1 = 10 # + +ns1 = 50 +ns = 2 * ns1 + +rng = np.random.RandomState(2) + +xt = rng.randn(nt, 2) * 0.2 +xt[:nt1, 0] += 1 +xt[nt1:, 1] += 1 + + +xs1 = rng.randn(ns1, 2) * 0.2 +xs1[:, 0] += 1 +xs2 = rng.randn(ns1, 2) * 0.2 +xs2[:, 1] += 1 + +xs = np.concatenate((xs1, xs2)) + +# Sample reweighting matrix H +H = np.zeros((ns, 2)) +H[:ns1, 0] = 1 / ns1 +H[ns1:, 1] = 1 / ns1 +# each columns sums to 1 and has weights only for samples form the +# corresponding source distribution + +M = ot.dist(xs, xt) + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + +pl.figure(1) +pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) +pl.scatter(xs1[:, 0], xs1[:, 1], label='Source $\mu^s_1$', alpha=0.5) +pl.scatter(xs2[:, 0], xs2[:, 1], label='Source $\mu^s_2$', alpha=0.5) +pl.title('Sources and Target distributions') +pl.legend() + + +############################################################################## +# Optimization of the model wrt the Wasserstein distance +# ------------------------------------------------------ + + +#%% Weights optimization with gradient descent + +# convert numpy arrays to torch tensors +H2 = torch.tensor(H) +M2 = torch.tensor(M) + +# weights for the source distributions +w = torch.tensor(ot.unif(2), requires_grad=True) + +# uniform weights for target +b = torch.tensor(ot.unif(nt)) + +lr = 2e-3 # learning rate +niter = 500 # number of iterations +losses = [] # loss along the iterations + +# loss for the minimal Wasserstein estimator + + +def get_loss(w): + a = torch.mv(H2, w) # distribution reweighting + return ot.emd2(a, b, M2) # squared Wasserstein 2 + + +for i in range(niter): + + loss = get_loss(w) + losses.append(float(loss)) + + loss.backward() + + with torch.no_grad(): + w -= lr * w.grad # gradient step + w[:] = ot.utils.proj_simplex(w) # projection on the simplex + + w.grad.zero_() + + +############################################################################## +# Estimated weights and convergence of the objective +# --------------------------------------------------- + +we = w.detach().numpy() +print('Estimated mixture:', we) + +pl.figure(2) +pl.semilogy(losses) +pl.grid() +pl.title('Wasserstein distance') +pl.xlabel("Iterations") + +############################################################################## +# Ploting the reweighted source distribution +# ------------------------------------------ + +pl.figure(3) + +# compute source weights +ws = H.dot(we) + +pl.scatter(xt[:, 0], xt[:, 1], label='Target $\mu^t$', alpha=0.5) +pl.scatter(xs[:, 0], xs[:, 1], color='C3', s=ws * 20 * ns, label='Weighted sources $\sum_{k} w_k\mu^s_k$', alpha=0.5) +pl.title('Target and reweighted source distributions') +pl.legend() diff --git a/ot/__init__.py b/ot/__init__.py index 5a8a415..3b072c6 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -33,6 +33,7 @@ from . import smooth from . import stochastic from . import unbalanced from . import partial +from . import backend # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d diff --git a/ot/backend.py b/ot/backend.py new file mode 100644 index 0000000..d68f5cf --- /dev/null +++ b/ot/backend.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +""" +Multi-lib backend for POT +""" + +# Author: Remi Flamary +# Nicolas Courty +# +# License: MIT License + +import numpy as np + +try: + import torch + torch_type = torch.Tensor +except ImportError: + torch = False + torch_type = float + +try: + import jax + import jax.numpy as jnp + jax_type = jax.numpy.ndarray +except ImportError: + jax = False + jax_type = float + +str_type_error = "All array should be from the same type/backend. Current types are : {}" + + +def get_backend_list(): + """ returns the list of available backends)""" + lst = [NumpyBackend(), ] + + if torch: + lst.append(TorchBackend()) + + if jax: + lst.append(JaxBackend()) + + return lst + + +def get_backend(*args): + """returns the proper backend for a list of input arrays + + Also raises TypeError if all arrays are not from the same backend + """ + # check that some arrays given + if not len(args) > 0: + raise ValueError(" The function takes at least one parameter") + # check all same type + + if isinstance(args[0], np.ndarray): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return NumpyBackend() + elif torch and isinstance(args[0], torch_type): + if not len(set(type(a) for a in args)) == 1: + raise ValueError(str_type_error.format([type(a) for a in args])) + return TorchBackend() + elif isinstance(args[0], jax_type): + return JaxBackend() + else: + raise ValueError("Unknown type of non implemented backend.") + + +def to_numpy(*args): + """returns numpy arrays from any compatible backend""" + + if len(args) == 1: + return get_backend(args[0]).to_numpy(args[0]) + else: + return [get_backend(a).to_numpy(a) for a in args] + + +class Backend(): + + __name__ = None + __type__ = None + + def __str__(self): + return self.__name__ + + # convert to numpy + def to_numpy(self, a): + raise NotImplementedError() + + # convert from numpy + def from_numpy(self, a, type_as=None): + raise NotImplementedError() + + def set_gradients(self, val, inputs, grads): + """ define the gradients for the value val wrt the inputs """ + raise NotImplementedError() + + def zeros(self, shape, type_as=None): + raise NotImplementedError() + + def ones(self, shape, type_as=None): + raise NotImplementedError() + + def arange(self, stop, start=0, step=1, type_as=None): + raise NotImplementedError() + + def full(self, shape, fill_value, type_as=None): + raise NotImplementedError() + + def eye(self, N, M=None, type_as=None): + raise NotImplementedError() + + def sum(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def cumsum(self, a, axis=None): + raise NotImplementedError() + + def max(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def min(self, a, axis=None, keepdims=False): + raise NotImplementedError() + + def maximum(self, a, b): + raise NotImplementedError() + + def minimum(self, a, b): + raise NotImplementedError() + + def dot(self, a, b): + raise NotImplementedError() + + def abs(self, a): + raise NotImplementedError() + + def exp(self, a): + raise NotImplementedError() + + def log(self, a): + raise NotImplementedError() + + def sqrt(self, a): + raise NotImplementedError() + + def norm(self, a): + raise NotImplementedError() + + def any(self, a): + raise NotImplementedError() + + def isnan(self, a): + raise NotImplementedError() + + def isinf(self, a): + raise NotImplementedError() + + def einsum(self, subscripts, *operands): + raise NotImplementedError() + + def sort(self, a, axis=-1): + raise NotImplementedError() + + def argsort(self, a, axis=None): + raise NotImplementedError() + + def flip(self, a, axis=None): + raise NotImplementedError() + + +class NumpyBackend(Backend): + + __name__ = 'numpy' + __type__ = np.ndarray + + def to_numpy(self, a): + return a + + def from_numpy(self, a, type_as=None): + if type_as is None: + return a + elif isinstance(a, float): + return a + else: + return a.astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for numpy + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return np.zeros(shape) + else: + return np.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return np.ones(shape) + else: + return np.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return np.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return np.full(shape, fill_value) + else: + return np.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return np.eye(N, M) + else: + return np.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return np.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return np.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return np.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return np.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return np.maximum(a, b) + + def minimum(self, a, b): + return np.minimum(a, b) + + def dot(self, a, b): + return np.dot(a, b) + + def abs(self, a): + return np.abs(a) + + def exp(self, a): + return np.exp(a) + + def log(self, a): + return np.log(a) + + def sqrt(self, a): + return np.sqrt(a) + + def norm(self, a): + return np.sqrt(np.sum(np.square(a))) + + def any(self, a): + return np.any(a) + + def isnan(self, a): + return np.isnan(a) + + def isinf(self, a): + return np.isinf(a) + + def einsum(self, subscripts, *operands): + return np.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return np.sort(a, axis) + + def argsort(self, a, axis=-1): + return np.argsort(a, axis) + + def flip(self, a, axis=None): + return np.flip(a, axis) + + +class JaxBackend(Backend): + + __name__ = 'jax' + __type__ = jax_type + + def to_numpy(self, a): + return np.array(a) + + def from_numpy(self, a, type_as=None): + if type_as is None: + return jnp.array(a) + else: + return jnp.array(a).astype(type_as.dtype) + + def set_gradients(self, val, inputs, grads): + # no gradients for jax because it is functional + + # does not work + # from jax import custom_jvp + # @custom_jvp + # def f(*inputs): + # return val + # f.defjvps(*grads) + # return f(*inputs) + + return val + + def zeros(self, shape, type_as=None): + if type_as is None: + return jnp.zeros(shape) + else: + return jnp.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return jnp.ones(shape) + else: + return jnp.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return jnp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return jnp.full(shape, fill_value) + else: + return jnp.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return jnp.eye(N, M) + else: + return jnp.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return jnp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return jnp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return jnp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return jnp.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return jnp.maximum(a, b) + + def minimum(self, a, b): + return jnp.minimum(a, b) + + def dot(self, a, b): + return jnp.dot(a, b) + + def abs(self, a): + return jnp.abs(a) + + def exp(self, a): + return jnp.exp(a) + + def log(self, a): + return jnp.log(a) + + def sqrt(self, a): + return jnp.sqrt(a) + + def norm(self, a): + return jnp.sqrt(jnp.sum(jnp.square(a))) + + def any(self, a): + return jnp.any(a) + + def isnan(self, a): + return jnp.isnan(a) + + def isinf(self, a): + return jnp.isinf(a) + + def einsum(self, subscripts, *operands): + return jnp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return jnp.sort(a, axis) + + def argsort(self, a, axis=-1): + return jnp.argsort(a, axis) + + def flip(self, a, axis=None): + return jnp.flip(a, axis) + + +class TorchBackend(Backend): + + __name__ = 'torch' + __type__ = torch_type + + def to_numpy(self, a): + return a.cpu().detach().numpy() + + def from_numpy(self, a, type_as=None): + if type_as is None: + return torch.from_numpy(a) + else: + return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) + + def set_gradients(self, val, inputs, grads): + from torch.autograd import Function + + # define a function that takes inputs and return val + class ValFunction(Function): + @staticmethod + def forward(ctx, *inputs): + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return grads + + return ValFunction.apply(*inputs) + + def zeros(self, shape, type_as=None): + if type_as is None: + return torch.zeros(shape) + else: + return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) + + def ones(self, shape, type_as=None): + if type_as is None: + return torch.ones(shape) + else: + return torch.ones(shape, dtype=type_as.dtype, device=type_as.device) + + def arange(self, stop, start=0, step=1, type_as=None): + if type_as is None: + return torch.arange(start, stop, step) + else: + return torch.arange(start, stop, step, device=type_as.device) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return torch.full(shape, fill_value) + else: + return torch.full(shape, fill_value, dtype=type_as.dtype, device=type_as.device) + + def eye(self, N, M=None, type_as=None): + if M is None: + M = N + if type_as is None: + return torch.eye(N, m=M) + else: + return torch.eye(N, m=M, dtype=type_as.dtype, device=type_as.device) + + def sum(self, a, axis=None, keepdims=False): + if axis is None: + return torch.sum(a) + else: + return torch.sum(a, axis, keepdim=keepdims) + + def cumsum(self, a, axis=None): + if axis is None: + return torch.cumsum(a.flatten(), 0) + else: + return torch.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + if axis is None: + return torch.max(a) + else: + return torch.max(a, axis, keepdim=keepdims)[0] + + def min(self, a, axis=None, keepdims=False): + if axis is None: + return torch.min(a) + else: + return torch.min(a, axis, keepdim=keepdims)[0] + + def maximum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.maximum(a, b) + + def minimum(self, a, b): + if isinstance(a, int) or isinstance(a, float): + a = torch.tensor([float(a)], dtype=b.dtype, device=b.device) + if isinstance(b, int) or isinstance(b, float): + b = torch.tensor([float(b)], dtype=a.dtype, device=a.device) + return torch.minimum(a, b) + + def dot(self, a, b): + if len(a.shape) == len(b.shape) == 1: + return torch.dot(a, b) + elif len(a.shape) == 2 and len(b.shape) == 1: + return torch.mv(a, b) + else: + return torch.mm(a, b) + + def abs(self, a): + return torch.abs(a) + + def exp(self, a): + return torch.exp(a) + + def log(self, a): + return torch.log(a) + + def sqrt(self, a): + return torch.sqrt(a) + + def norm(self, a): + return torch.sqrt(torch.sum(torch.square(a))) + + def any(self, a): + return torch.any(a) + + def isnan(self, a): + return torch.isnan(a) + + def isinf(self, a): + return torch.isinf(a) + + def einsum(self, subscripts, *operands): + return torch.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + sorted0, indices = torch.sort(a, dim=axis) + return sorted0 + + def argsort(self, a, axis=-1): + sorted, indices = torch.sort(a, dim=axis) + return indices + + def flip(self, a, axis=None): + if axis is None: + return torch.flip(a, tuple(i for i in range(len(a.shape)))) + if isinstance(axis, int): + return torch.flip(a, (axis,)) + else: + return torch.flip(a, dims=axis) diff --git a/ot/bregman.py b/ot/bregman.py index 559db14..b10effd 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,7 +19,8 @@ import warnings import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import unif, dist +from ot.utils import unif, dist, list_to_array +from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -43,17 +44,36 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in [2]_ + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -69,25 +89,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - - Returns ------- - gamma : ndarray, shape (dim_a, dim_b) + gamma : array-like, shape (dim_a, dim_b) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -166,17 +170,35 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - a and b are source and target weights (histograms, both sum to 1) + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [2]_ + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. + Parameters ---------- - a : ndarray, shape (dim_a,) + a : array-like, shape (dim_a,) samples weights in the source domain - b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : ndarray, shape (dim_a, dim_b) + M : array-like, shape (dim_a, dim_b) loss matrix reg : float Regularization term >0 @@ -191,28 +213,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :any:`ot.bregman.sinkhorn_epsilon_scaling` that relie on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :any:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. - Returns ------- - W : (n_hists) ndarray + W : (n_hists) float/array-like Optimal transportation loss for the given parameters log : dict log dictionary return only if log==True in parameters + Examples -------- @@ -247,7 +255,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] """ - b = np.asarray(b, dtype=np.float64) + + b = list_to_array(b) if len(b.shape) < 2: b = b[:, None] @@ -339,14 +348,14 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, """ - a = np.asarray(a, dtype=np.float64) - b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) # init data dim_a = len(a) @@ -363,21 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K cpt = 0 @@ -386,13 +387,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, uprev = u vprev = v - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -403,11 +404,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) else: # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal if log: log['err'].append(err) @@ -422,7 +423,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log['v'] = v if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index 7478fb9..e939610 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -25,6 +25,8 @@ result of the function with parameter ``to_numpy=False``. # # License: MIT License +import warnings + from . import bregman from . import da from .bregman import sinkhorn @@ -34,7 +36,7 @@ from . import utils from .utils import dist, to_gpu, to_np - +warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning) __all__ = ["utils", "dist", "sinkhorn", diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d5c3a5e..c8c9da6 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -18,8 +18,9 @@ from . import cvx from .cvx import barycenter # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from ..utils import dist +from ..utils import dist, list_to_array from ..utils import parmap +from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] @@ -176,8 +177,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): r"""Solves the Earth Movers distance problem and returns the OT matrix - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + .. math:: \gamma = arg\min_\gamma <\gamma,M>_F s.t. \gamma 1 = a @@ -189,37 +189,41 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. warning:: Note that the M matrix in numpy needs to be a C-order + numpy.array in float64 format. It will be converted if not in this + format + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual - variables. Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. Returns ------- - gamma: (ns x nt) numpy.ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is true, a dictionary containing the cost and dual - variables and exit status + gamma: array-like, shape (ns, nt) + Optimal transportation matrix for the given + parameters + log: dict, optional + If input log is true, a dictionary containing the + cost and dual variables and exit status Examples @@ -232,26 +236,37 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): >>> a=[.5,.5] >>> b=[.5,.5] >>> M=[[0.,1.],[1.,0.]] - >>> ot.emd(a,b,M) + >>> ot.emd(a, b, M) array([[0.5, 0. ], [0. , 0.5]]) References ---------- - .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. - (2011, December). Displacement interpolation using Lagrangian mass - transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. - 158). ACM. + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011, + December). Displacement interpolation using Lagrangian mass transport. + In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM. See Also -------- - ot.bregman.sinkhorn : Entropic regularized OT - ot.optim.cg : General regularized OT""" - + ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General + regularized OT""" + + # convert to numpy if list + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -262,6 +277,11 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + asel = a != 0 bsel = b != 0 @@ -277,12 +297,12 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if log: log = {} log['cost'] = cost - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - return G, log - return G + return nx.from_numpy(G, type_as=M0), log + return nx.from_numpy(G, type_as=M0) def emd2(a, b, M, processes=multiprocessing.cpu_count(), @@ -303,20 +323,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), - M is the metric cost matrix - a and b are the sample weights - .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 - format. + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) numpy.ndarray, float64 + a : (ns,) array-like, float64 Source histogram (uniform weight if empty list) - b : (nt,) numpy.ndarray, float64 + b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) numpy.ndarray, float64 - Loss matrix (c-order array with type float64) + M : (ns,nt) array-like, float64 + Loss matrix (for numpy c-order array with type float64) processes : int, optional (default=nb cpu) Nb of processes used for multiple emd computation (not used on windows) numItermax : int, optional (default=100000) @@ -333,9 +352,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), Returns ------- - W: float + W: float, array-like Optimal transportation loss for the given parameters - log: dictnp + log: dict If input log is true, a dictionary containing dual variables and exit status @@ -367,12 +386,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" + a, b, M = list_to_array(a, b, M) + + a0, b0, M0 = a, b, M + nx = get_backend(M0, a0, b0) + + # convert to numpy + M = nx.to_numpy(M) + a = nx.to_numpy(a) + b = nx.to_numpy(b) + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) + M = np.asarray(M, dtype=np.float64, order= 'C') # problem with pikling Forks - if sys.platform.endswith('win32'): + if sys.platform.endswith('win32') or not nx.__name__ == 'numpy': processes = 1 # if empty array given then use uniform distributions @@ -400,12 +429,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), result_code_string = check_result(result_code) log = {} + G = nx.from_numpy(G, type_as=M0) if return_matrix: log['G'] = G - log['u'] = u - log['v'] = v + log['u'] = nx.from_numpy(u, type_as=a0) + log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -418,6 +450,11 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) + G = nx.from_numpy(G, type_as=M0) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0,b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0),G)) + check_result(result_code) return cost @@ -637,6 +674,10 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, if b.ndim == 0 or len(b) == 0: b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0] + # ensure that same mass + np.testing.assert_almost_equal(a.sum(0),b.sum(0),err_msg='a and b vector must have the same sum') + b=b*a.sum()/b.sum() + x_a_1d = x_a.reshape((-1,)) x_b_1d = x_b.reshape((-1,)) perm_a = np.argsort(x_a_1d) diff --git a/ot/utils.py b/ot/utils.py index 544c569..4dac0c5 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -16,6 +16,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature +from .backend import get_backend __time_tic_toc = time.time() @@ -41,8 +42,11 @@ def toq(): def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): """Compute kernel matrix""" + + nx = get_backend(x1, x2) + if method.lower() in ['gaussian', 'gauss', 'rbf']: - K = np.exp(-dist(x1, x2) / (2 * sigma**2)) + K = nx.exp(-dist(x1, x2) / (2 * sigma**2)) return K @@ -52,6 +56,66 @@ def laplacian(x): return L +def list_to_array(*lst): + """ Convert a list if in numpy format """ + if len(lst) > 1: + return [np.array(a) if isinstance(a, list) else a for a in lst] + else: + return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + + +def proj_simplex(v, z=1): + r""" compute the closest point (orthogonal projection) on the + generalized (n-1)-simplex of a vector v wrt. to the Euclidean + distance, thus solving: + .. math:: + \mathcal{P}(w) \in arg\min_\gamma || \gamma - v ||_2 + + s.t. \gamma^T 1= z + + \gamma\geq 0 + + If v is a 2d array, compute all the projections wrt. axis 0 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + Parameters + ---------- + v : {array-like}, shape (n, d) + z : int, optional + 'size' of the simplex (each vectors sum to z, 1 by default) + + Returns + ------- + h : ndarray, shape (n,d) + Array of projections on the simplex + """ + nx = get_backend(v) + n = v.shape[0] + if v.ndim == 1: + d1 = 1 + v = v[:, None] + else: + d1 = 0 + d = v.shape[1] + + # sort u in ascending order + u = nx.sort(v, axis=0) + # take the descending order + u = nx.flip(u, 0) + cssv = nx.cumsum(u, axis=0) - z + ind = nx.arange(n, type_as=v)[:, None] + 1 + cond = u - cssv / ind > 0 + rho = nx.sum(cond, 0) + theta = cssv[rho - 1, nx.arange(d)] / rho + w = nx.maximum(v - theta[None, :], nx.zeros(v.shape, type_as=v)) + if d1: + return w[:, 0] + else: + return w + + def unif(n): """ return a uniform histogram of length n (simplex) @@ -84,52 +148,68 @@ def euclidean_distances(X, Y, squared=False): """ Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + Parameters ---------- X : {array-like}, shape (n_samples_1, n_features) Y : {array-like}, shape (n_samples_2, n_features) squared : boolean, optional Return squared Euclidean distances. + Returns ------- distances : {array}, shape (n_samples_1, n_samples_2) """ - XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis] - YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :] - distances = np.dot(X, Y.T) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + + nx = get_backend(X, Y) + + a2 = nx.einsum('ij,ij->i', X, X) + b2 = nx.einsum('ij,ij->i', Y, Y) + + c = -2 * nx.dot(X, Y.T) + c += a2[:, None] + c += b2[None, :] + + c = nx.maximum(c, 0) + + if not squared: + c = nx.sqrt(c) + if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 - return distances if squared else np.sqrt(distances, out=distances) + c = c * (1 - nx.eye(X.shape[0], type_as=c)) + + return c def dist(x1, x2=None, metric='sqeuclidean'): - """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist + """Compute distance between samples in x1 and x2 + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. Parameters ---------- - x1 : ndarray, shape (n1,d) + x1 : array-like, shape (n1,d) matrix with n1 samples of size d - x2 : array, shape (n2,d), optional + x2 : array-like, shape (n2,d), optional matrix with n2 samples of size d (if None then x2=x1) metric : str | callable, optional - Name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be 'braycurtis', 'canberra', 'chebyshev', 'cityblock', - 'correlation', 'cosine', 'dice', 'euclidean', 'hamming', 'jaccard', 'kulsinski', - 'mahalanobis', 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', + 'sqeuclidean' or 'euclidean' on all backends. On numpy the function also + accepts from the scipy.spatial.distance.cdist function : 'braycurtis', + 'canberra', 'chebyshev', 'cityblock', 'correlation', 'cosine', 'dice', + 'euclidean', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', + 'matching', 'minkowski', 'rogerstanimoto', 'russellrao', 'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'wminkowski', 'yule'. Returns ------- - M : np.array (n1,n2) + M : array-like, shape (n1, n2) distance matrix computed with given metric """ @@ -137,7 +217,13 @@ def dist(x1, x2=None, metric='sqeuclidean'): x2 = x1 if metric == "sqeuclidean": return euclidean_distances(x1, x2, squared=True) - return cdist(x1, x2, metric=metric) + elif metric == "euclidean": + return euclidean_distances(x1, x2, squared=False) + else: + if not get_backend(x1, x2).__name__ == 'numpy': + raise NotImplementedError() + else: + return cdist(x1, x2, metric=metric) def dist0(n, method='lin_square'): diff --git a/requirements.txt b/requirements.txt index 331dd57..4353247 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,7 @@ pymanopt==0.2.4; python_version <'3' pymanopt; python_version >= '3' cvxopt scikit-learn +torch +jax +jaxlib pytest \ No newline at end of file diff --git a/test/test_backend.py b/test/test_backend.py new file mode 100644 index 0000000..bc5b00c --- /dev/null +++ b/test/test_backend.py @@ -0,0 +1,364 @@ +"""Tests for backend module """ + +# Author: Remi Flamary +# +# License: MIT License + +import ot +import ot.backend +from ot.backend import torch, jax + +import pytest + +import numpy as np +from numpy.testing import assert_array_almost_equal_nulp + +from ot.backend import get_backend, get_backend_list, to_numpy + + +backend_list = get_backend_list() + + +def test_get_backend_list(): + + lst = get_backend_list() + + assert len(lst) > 0 + assert isinstance(lst[0], ot.backend.NumpyBackend) + + +@pytest.mark.parametrize('nx', backend_list) +def test_to_numpy(nx): + + v = nx.zeros(10) + M = nx.ones((10, 10)) + + v2 = to_numpy(v) + assert isinstance(v2, np.ndarray) + + v2, M2 = to_numpy(v, M) + assert isinstance(M2, np.ndarray) + + +def test_get_backend(): + + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) + + nx = get_backend(A) + assert nx.__name__ == 'numpy' + + nx = get_backend(A, B) + assert nx.__name__ == 'numpy' + + # error if no parameters + with pytest.raises(ValueError): + get_backend() + + # error if unknown types + with pytest.raises(ValueError): + get_backend(1, 2.0) + + # test torch + if torch: + + A2 = torch.from_numpy(A) + B2 = torch.from_numpy(B) + + nx = get_backend(A2) + assert nx.__name__ == 'torch' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'torch' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + + if jax: + + A2 = jax.numpy.array(A) + B2 = jax.numpy.array(B) + + nx = get_backend(A2) + assert nx.__name__ == 'jax' + + nx = get_backend(A2, B2) + assert nx.__name__ == 'jax' + + # test not unique types in input + with pytest.raises(ValueError): + get_backend(A, B2) + + +@pytest.mark.parametrize('nx', backend_list) +def test_convert_between_backends(nx): + + A = np.zeros((3, 2)) + B = np.zeros((3, 1)) + + A2 = nx.from_numpy(A) + B2 = nx.from_numpy(B) + + assert isinstance(A2, nx.__type__) + assert isinstance(B2, nx.__type__) + + nx2 = get_backend(A2, B2) + + assert nx2.__name__ == nx.__name__ + + assert_array_almost_equal_nulp(nx.to_numpy(A2), A) + assert_array_almost_equal_nulp(nx.to_numpy(B2), B) + + +def test_empty_backend(): + + rnd = np.random.RandomState(0) + M = rnd.randn(10, 3) + v = rnd.randn(3) + + nx = ot.backend.Backend() + + with pytest.raises(NotImplementedError): + nx.from_numpy(M) + with pytest.raises(NotImplementedError): + nx.to_numpy(M) + with pytest.raises(NotImplementedError): + nx.set_gradients(0, 0, 0) + with pytest.raises(NotImplementedError): + nx.zeros((10, 3)) + with pytest.raises(NotImplementedError): + nx.ones((10, 3)) + with pytest.raises(NotImplementedError): + nx.arange(10, 1, 2) + with pytest.raises(NotImplementedError): + nx.full((10, 3), 3.14) + with pytest.raises(NotImplementedError): + nx.eye((10, 3)) + with pytest.raises(NotImplementedError): + nx.sum(M) + with pytest.raises(NotImplementedError): + nx.cumsum(M) + with pytest.raises(NotImplementedError): + nx.max(M) + with pytest.raises(NotImplementedError): + nx.min(M) + with pytest.raises(NotImplementedError): + nx.maximum(v, v) + with pytest.raises(NotImplementedError): + nx.minimum(v, v) + with pytest.raises(NotImplementedError): + nx.abs(M) + with pytest.raises(NotImplementedError): + nx.log(M) + with pytest.raises(NotImplementedError): + nx.exp(M) + with pytest.raises(NotImplementedError): + nx.sqrt(M) + with pytest.raises(NotImplementedError): + nx.dot(v, v) + with pytest.raises(NotImplementedError): + nx.norm(M) + with pytest.raises(NotImplementedError): + nx.exp(M) + with pytest.raises(NotImplementedError): + nx.any(M) + with pytest.raises(NotImplementedError): + nx.isnan(M) + with pytest.raises(NotImplementedError): + nx.isinf(M) + with pytest.raises(NotImplementedError): + nx.einsum('ij->i', M) + with pytest.raises(NotImplementedError): + nx.sort(M) + with pytest.raises(NotImplementedError): + nx.argsort(M) + with pytest.raises(NotImplementedError): + nx.flip(M) + + +@pytest.mark.parametrize('backend', backend_list) +def test_func_backends(backend): + + rnd = np.random.RandomState(0) + M = rnd.randn(10, 3) + v = rnd.randn(3) + val = np.array([1.0]) + + lst_tot = [] + + for nx in [ot.backend.NumpyBackend(), backend]: + + print('Backend: ', nx.__name__) + + lst_b = [] + lst_name = [] + + Mb = nx.from_numpy(M) + vb = nx.from_numpy(v) + val = nx.from_numpy(val) + + A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) + lst_name.append('set_gradients') + + A = nx.zeros((10, 3)) + A = nx.zeros((10, 3), type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('zeros') + + A = nx.ones((10, 3)) + A = nx.ones((10, 3), type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('ones') + + A = nx.arange(10, 1, 2) + lst_b.append(nx.to_numpy(A)) + lst_name.append('arange') + + A = nx.full((10, 3), 3.14) + A = nx.full((10, 3), 3.14, type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('full') + + A = nx.eye(10, 3) + A = nx.eye(10, 3, type_as=Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('eye') + + A = nx.sum(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sum') + + A = nx.sum(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sum(axis)') + + A = nx.cumsum(Mb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('cumsum(axis)') + + A = nx.max(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('max') + + A = nx.max(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('max(axis)') + + A = nx.min(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('min') + + A = nx.min(Mb, axis=1, keepdims=True) + lst_b.append(nx.to_numpy(A)) + lst_name.append('min(axis)') + + A = nx.maximum(vb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('maximum') + + A = nx.minimum(vb, 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('minimum') + + A = nx.abs(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('abs') + + A = nx.log(A) + lst_b.append(nx.to_numpy(A)) + lst_name.append('log') + + A = nx.exp(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('exp') + + A = nx.sqrt(nx.abs(Mb)) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sqrt') + + A = nx.dot(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(v,v)') + + A = nx.dot(Mb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(M,v)') + + A = nx.dot(Mb, Mb.T) + lst_b.append(nx.to_numpy(A)) + lst_name.append('dot(M,M)') + + A = nx.norm(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('norm') + + A = nx.any(vb > 0) + lst_b.append(nx.to_numpy(A)) + lst_name.append('any') + + A = nx.isnan(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('isnan') + + A = nx.isinf(vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('isinf') + + A = nx.einsum('ij->i', Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('einsum(ij->i)') + + A = nx.einsum('ij,j->i', Mb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('nx.einsum(ij,j->i)') + + A = nx.einsum('ij->i', Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('nx.einsum(ij->i)') + + A = nx.sort(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('sort') + + A = nx.argsort(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('argsort') + + A = nx.flip(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('flip') + + lst_tot.append(lst_b) + + lst_np = lst_tot[0] + lst_b = lst_tot[1] + + for a1, a2, name in zip(lst_np, lst_b, lst_name): + if not np.allclose(a1, a2): + print('Assert fail on: ', name) + assert np.allclose(a1, a2, atol=1e-7) + + +def test_gradients_backends(): + + rnd = np.random.RandomState(0) + v = rnd.randn(10) + c = rnd.randn(1) + + if torch: + + nx = ot.backend.TorchBackend() + + v2 = torch.tensor(v, requires_grad=True) + c2 = torch.tensor(c, requires_grad=True) + + val = c2 * torch.sum(v2 * v2) + + val2 = nx.set_gradients(val, (v2, c2), (v2, c2)) + + val2.backward() + + assert torch.equal(v2.grad, v2) + assert torch.equal(c2.grad, c2) diff --git a/test/test_bregman.py b/test/test_bregman.py index 1ebd21f..7c5162a 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -9,6 +9,10 @@ import numpy as np import pytest import ot +from ot.backend import get_backend_list +from ot.backend import torch + +backend_list = get_backend_list() def test_sinkhorn(): @@ -30,6 +34,76 @@ def test_sinkhorn(): u, G.sum(0), atol=1e-05) # cf convergence sinkhorn +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_sinkhorn2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.sinkhorn(a, a, M, 1) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.sinkhorn2(ab, ab, Mb, 1) + + np.allclose(G, nx.to_numpy(Gb)) + + +def test_sinkhorn2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.sinkhorn2(a1, b1, M1, 1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + + def test_sinkhorn_empty(): # test sinkhorn n = 100 diff --git a/test/test_gromov.py b/test/test_gromov.py index 43da9fc..81138ca 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -181,7 +181,7 @@ def test_fgw(): M = ot.dist(ys, yt) M /= M.max() - G = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5) + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) # check constratints np.testing.assert_allclose( @@ -242,9 +242,9 @@ def test_fgw_barycenter(): init_X = np.random.randn(n_samples, ys.shape[1]) - X, C = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_X, - p=ot.unif(n_samples), loss_fun='square_loss', - max_iter=100, tol=1e-3) + X, C, log = ot.gromov.fgw_barycenters(n_samples, [ys, yt], [C1, C2], [ot.unif(ns), ot.unif(nt)], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_X, + p=ot.unif(n_samples), loss_fun='square_loss', + max_iter=100, tol=1e-3, log=True) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) diff --git a/test/test_ot.py b/test/test_ot.py index f45e4c9..3e953dc 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -12,9 +12,12 @@ from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss +from ot.backend import get_backend_list, torch +backend_list = get_backend_list() -def test_emd_dimension_mismatch(): + +def test_emd_dimension_and_mass_mismatch(): # test emd and emd2 for dimension mismatch n_samples = 100 n_features = 2 @@ -29,6 +32,80 @@ def test_emd_dimension_mismatch(): np.testing.assert_raises(AssertionError, ot.emd2, a, a, M) + b = a.copy() + a[0] = 100 + np.testing.assert_raises(AssertionError, ot.emd, a, b, M) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + G = ot.emd(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + Gb = ot.emd(ab, ab, Mb) + + np.allclose(G, nx.to_numpy(Gb)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_emd2_backends(nx): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + val = ot.emd2(a, a, M) + + ab = nx.from_numpy(a) + Mb = nx.from_numpy(M) + + valb = ot.emd2(ab, ab, Mb) + + np.allclose(val, nx.to_numpy(valb)) + + +def test_emd2_gradients(): + n_samples = 100 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples, n_features) + y = rng.randn(n_samples, n_features) + a = ot.utils.unif(n_samples) + + M = ot.dist(x, y) + + if torch: + + a1 = torch.tensor(a, requires_grad=True) + b1 = torch.tensor(a, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.emd2(a1, b1, M1) + + val.backward() + + assert a1.shape == a1.grad.shape + assert b1.shape == b1.grad.shape + assert M1.shape == M1.grad.shape + def test_emd_emd2(): # test emd and emd2 for simple identity @@ -83,7 +160,7 @@ def test_emd_1d_emd2_1d(): np.testing.assert_allclose(np.ones((m,)) / m, G.sum(0)) # check G is similar - np.testing.assert_allclose(G, G_1d) + np.testing.assert_allclose(G, G_1d, atol=1e-15) # check AssertionError is raised if called on non 1d arrays u = np.random.randn(n, 2) @@ -292,16 +369,6 @@ def test_warnings(): ot.emd(a, b, M, numItermax=1) assert "numItermax" in str(w[-1].message) #assert len(w) == 1 - a[0] = 100 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 2 - a[0] = -1 - print('Computing {} EMD '.format(2)) - ot.emd(a, b, M) - assert "infeasible" in str(w[-1].message) - #assert len(w) == 3 def test_dual_variables(): diff --git a/test/test_partial.py b/test/test_partial.py index 121f345..3571e2a 100755 --- a/test/test_partial.py +++ b/test/test_partial.py @@ -129,9 +129,9 @@ def test_partial_wasserstein(): # check constratints np.testing.assert_equal( - G.sum(1) <= p, [True] * len(p)) # cf convergence wasserstein + G.sum(1) - p <= 1e-5, [True] * len(p)) # cf convergence wasserstein np.testing.assert_equal( - G.sum(0) <= q, [True] * len(q)) # cf convergence wasserstein + G.sum(0) - q <= 1e-5, [True] * len(q)) # cf convergence wasserstein np.testing.assert_allclose( np.sum(G), m, atol=1e-04) diff --git a/test/test_utils.py b/test/test_utils.py index db9cda6..76b1faa 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,11 +4,47 @@ # # License: MIT License - +import pytest import ot import numpy as np import sys +from ot.backend import get_backend_list + +backend_list = get_backend_list() + + +@pytest.mark.parametrize('nx', backend_list) +def test_proj_simplex(nx): + n = 10 + rng = np.random.RandomState(0) + + # test on matrix when projection is done on axis 0 + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # all projections should sum to 3 + proj = ot.utils.proj_simplex(x1, 3) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = 3 * np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + + # tets on vector + x = rng.randn(n) + x1 = nx.from_numpy(x) + + # all projections should sum to 1 + proj = ot.utils.proj_simplex(x1) + l1 = np.sum(nx.to_numpy(proj), axis=0) + l2 = np.ones(2) + np.testing.assert_allclose(l1, l2, atol=1e-5) + def test_parmap(): @@ -45,8 +81,8 @@ def test_tic_toc(): def test_kernel(): n = 100 - - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) K = ot.utils.kernel(x, x) @@ -67,7 +103,8 @@ def test_dist(): n = 100 - x = np.random.randn(n, 2) + rng = np.random.RandomState(0) + x = rng.randn(n, 2) D = np.zeros((n, n)) for i in range(n): @@ -78,8 +115,27 @@ def test_dist(): D3 = ot.dist(x) # dist shoul return squared euclidean - np.testing.assert_allclose(D, D2) - np.testing.assert_allclose(D, D3) + np.testing.assert_allclose(D, D2, atol=1e-14) + np.testing.assert_allclose(D, D3, atol=1e-14) + + +@ pytest.mark.parametrize('nx', backend_list) +def test_dist_backends(nx): + + n = 100 + rng = np.random.RandomState(0) + x = rng.randn(n, 2) + x1 = nx.from_numpy(x) + + lst_metric = ['euclidean', 'sqeuclidean'] + + for metric in lst_metric: + + D = ot.dist(x, x, metric=metric) + D1 = ot.dist(x1, x1, metric=metric) + + # low atol because jax forces float32 + np.testing.assert_allclose(D, nx.to_numpy(D1), atol=1e-5) def test_dist0(): @@ -95,9 +151,11 @@ def test_dots(): n1, n2, n3, n4 = 100, 50, 200, 100 - A = np.random.randn(n1, n2) - B = np.random.randn(n2, n3) - C = np.random.randn(n3, n4) + rng = np.random.RandomState(0) + + A = rng.randn(n1, n2) + B = rng.randn(n2, n3) + C = rng.randn(n3, n4) X1 = ot.utils.dots(A, B, C) -- cgit v1.2.3 From d693ac25988dd557cb1ee7fc96f3a656f7d4301c Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Wed, 2 Jun 2021 12:59:35 +0200 Subject: [WIP] Add Wasserstein GAN and debug memory leak (#254) * add example and debug memory leak * print stuff * speedup gallery * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * test cells * proper header gan exmaple * cleanup sections * last changes ? Co-authored-by: Alexandre Gramfort --- examples/backends/plot_wass2_gan_torch.py | 195 +++++++++++++++++++++ .../domain-adaptation/plot_otda_color_images.py | 2 +- .../plot_otda_mapping_colors_images.py | 2 +- ot/backend.py | 34 ++-- 4 files changed, 220 insertions(+), 13 deletions(-) create mode 100644 examples/backends/plot_wass2_gan_torch.py (limited to 'examples') diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py new file mode 100644 index 0000000..8f50022 --- /dev/null +++ b/examples/backends/plot_wass2_gan_torch.py @@ -0,0 +1,195 @@ +# -*- coding: utf-8 -*- +r""" +======================================== +Wasserstein 2 Minibatch GAN with PyTorch +======================================== + +In this example we train a Wasserstein GAN using Wasserstein 2 on minibatches +as a distribution fitting term. + +We want to train a generator :math:`G_\theta` that generates realistic +data from random noise drawn form a Gaussian :math:`\mu_n` distribution so +that the data is indistinguishable from true data in the data distribution +:math:`\mu_d`. To this end Wasserstein GAN [Arjovsky2017] aim at optimizing +the parameters :math:`\theta` of the generator with the following +optimization problem: + +.. math:: + \min_{\theta} W(\mu_d,G_\theta\#\mu_n) + + +In practice we do not have access to the full distribution :math:`\mu_d` but +samples and we cannot compute the Wasserstein distance for lare dataset. +[Arjovsky2017] proposed to approximate the dual potential of Wasserstein 1 +with a neural network recovering an optimization problem similar to GAN. +In this example +we will optimize the expectation of the Wasserstein distance over minibatches +at each iterations as proposed in [Genevay2018]. Optimizing the Minibatches +of the Wasserstein distance has been studied in[Fatras2019]. + +[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July). +Wasserstein generative adversarial networks. In International conference +on machine learning (pp. 214-223). PMLR. + +[Genevay2018] Genevay, Aude, Gabriel Peyré, and Marco Cuturi. "Learning generative models +with sinkhorn divergences." International Conference on Artificial Intelligence +and Statistics. PMLR, 2018. + +[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N. +(2020, June). Learning with minibatch Wasserstein: asymptotic and gradient +properties. In the 23nd International Conference on Artificial Intelligence +and Statistics (Vol. 108). + +""" + +# Author: Remi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import torch +from torch import nn +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) +sigma = 0.1 +n_dims = 2 +n_features = 2 + + +def get_data(n_samples): + c = torch.rand(size=(n_samples, 1)) + angle = c * 2 * np.pi + x = torch.cat((torch.cos(angle), torch.sin(angle)), 1) + x += torch.randn(n_samples, 2) * sigma + return x + + +# %% +# Plot data +# --------- + +# plot the distributions +x = get_data(500) +pl.figure(1) +pl.scatter(x[:, 0], x[:, 1], label='Data samples from $\mu_d$', alpha=0.5) +pl.title('Data distribution') +pl.legend() + + +# %% +# Generator Model +# --------------- + +# define the MLP model +class Generator(torch.nn.Module): + def __init__(self): + super(Generator, self).__init__() + self.fc1 = nn.Linear(n_features, 200) + self.fc2 = nn.Linear(200, 500) + self.fc3 = nn.Linear(500, n_dims) + self.relu = torch.nn.ReLU() # instead of Heaviside step fn + + def forward(self, x): + output = self.fc1(x) + output = self.relu(output) # instead of Heaviside step fn + output = self.fc2(output) + output = self.relu(output) + output = self.fc3(output) + return output + +# %% +# Training the model +# ------------------ + + +G = Generator() +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) + +# number of iteration and size of the batches +n_iter = 500 +size_batch = 500 + +# generate statis samples to see their trajectory along training +n_visu = 100 +xnvisu = torch.randn(n_visu, n_features) +xvisu = torch.zeros(n_iter, n_visu, n_dims) + +ab = torch.ones(size_batch) / size_batch +losses = [] + + +for i in range(n_iter): + + # generate noise samples + xn = torch.randn(size_batch, n_features) + + # generate data samples + xd = get_data(size_batch) + + # generate sample along iterations + xvisu[i, :, :] = G(xnvisu).detach() + + # generate smaples and compte distance matrix + xg = G(xn) + M = ot.dist(xg, xd) + + loss = ot.emd2(ab, ab, M) + losses.append(float(loss.detach())) + + if i % 10 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + + loss.backward() + optimizer.step() + + del M + +pl.figure(2) +pl.semilogy(losses) +pl.grid() +pl.title('Wasserstein distance') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + + +pl.figure(3, (10, 10)) + +ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] + +for i in range(9): + pl.subplot(3, 3, i + 1) + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.title('Iter. {}'.format(ivisu[i])) + if i == 0: + pl.legend() + +# %% +# Generate and visualize data +# --------------------------- + +size_batch = 500 +xd = get_data(size_batch) +xn = torch.randn(size_batch, 2) +x = G(xn).detach().numpy() + +pl.figure(4) +pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) +pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.title('Sources and Target distributions') +pl.legend() diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index d70f1fc..6218b13 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -53,7 +53,7 @@ X1 = im2mat(I1) X2 = im2mat(I2) # training samples -nb = 1000 +nb = 500 idx1 = r.randint(X1.shape[0], size=(nb,)) idx2 = r.randint(X2.shape[0], size=(nb,)) diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index aa41d22..72010a6 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -56,7 +56,7 @@ X1 = im2mat(I1) X2 = im2mat(I2) # training samples -nb = 1000 +nb = 500 idx1 = r.randint(X1.shape[0], size=(nb,)) idx2 = r.randint(X2.shape[0], size=(nb,)) diff --git a/ot/backend.py b/ot/backend.py index d68f5cf..8f46900 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -389,6 +389,26 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + def __init__(self): + + from torch.autograd import Function + + # define a function that takes inputs val and grads + # ad returns a val tensor with proper gradients + class ValFunction(Function): + + @staticmethod + def forward(ctx, val, grads, *inputs): + ctx.grads = grads + return val + + @staticmethod + def backward(ctx, grad_output): + # the gradients are grad + return (None, None) + ctx.grads + + self.ValFunction = ValFunction + def to_numpy(self, a): return a.cpu().detach().numpy() @@ -399,20 +419,12 @@ class TorchBackend(Backend): return torch.as_tensor(a, dtype=type_as.dtype, device=type_as.device) def set_gradients(self, val, inputs, grads): - from torch.autograd import Function - # define a function that takes inputs and return val - class ValFunction(Function): - @staticmethod - def forward(ctx, *inputs): - return val + Func = self.ValFunction() - @staticmethod - def backward(ctx, grad_output): - # the gradients are grad - return grads + res = Func.apply(val, grads, *inputs) - return ValFunction.apply(*inputs) + return res def zeros(self, shape, type_as=None): if type_as is None: -- cgit v1.2.3 From 982510eb5085a0edd7a00fb96a308854957d32bf Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Tue, 8 Jun 2021 22:32:18 +0200 Subject: [MRG] Update example GAN to avoid the 10 minute CircleCI limit (#258) * shortened example GAN * pep8 and typo * awesome animation * small eror pep8 * add animation to doc * better timing animation * tune step --- docs/source/conf.py | 4 +++- examples/backends/plot_wass2_gan_torch.py | 40 +++++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 5 deletions(-) (limited to 'examples') diff --git a/docs/source/conf.py b/docs/source/conf.py index 3a11798..9b5a719 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -337,7 +337,8 @@ texinfo_documents = [ intersphinx_mapping = {'python': ('https://docs.python.org/3', None), 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), - 'matplotlib': ('http://matplotlib.org/', None)} + 'matplotlib': ('http://matplotlib.org/', None), + 'torch': ('https://pytorch.org/docs/stable/', None)} sphinx_gallery_conf = { 'examples_dirs': ['../../examples', '../../examples/da'], @@ -345,6 +346,7 @@ sphinx_gallery_conf = { 'backreferences_dir': 'gen_modules/backreferences', 'inspect_global_variables' : True, 'doc_module' : ('ot','numpy','scipy','pylab'), + 'matplotlib_animations': True, 'reference_url': { 'ot': None} } diff --git a/examples/backends/plot_wass2_gan_torch.py b/examples/backends/plot_wass2_gan_torch.py index 8f50022..ca5b3c9 100644 --- a/examples/backends/plot_wass2_gan_torch.py +++ b/examples/backends/plot_wass2_gan_torch.py @@ -50,6 +50,7 @@ and Statistics (Vol. 108). import numpy as np import matplotlib.pyplot as pl +import matplotlib.animation as animation import torch from torch import nn import ot @@ -112,10 +113,10 @@ class Generator(torch.nn.Module): G = Generator() -optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) +optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5) # number of iteration and size of the batches -n_iter = 500 +n_iter = 200 # set to 200 for doc build but 1000 is better ;) size_batch = 500 # generate statis samples to see their trajectory along training @@ -167,7 +168,7 @@ pl.xlabel("Iterations") pl.figure(3, (10, 10)) -ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] +ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199] for i in range(9): pl.subplot(3, 3, i + 1) @@ -179,6 +180,37 @@ for i in range(9): if i == 0: pl.legend() +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + pl.clf() + pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) + pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) + pl.xticks(()) + pl.yticks(()) + pl.xlim((-1.5, 1.5)) + pl.ylim((-1.5, 1.5)) + pl.title('Iter. {}'.format(i)) + return 1 + + +i = 0 +pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) +pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) +pl.xticks(()) +pl.yticks(()) +pl.xlim((-1.5, 1.5)) +pl.ylim((-1.5, 1.5)) +pl.title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000) + # %% # Generate and visualize data # --------------------------- @@ -188,7 +220,7 @@ xd = get_data(size_batch) xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy() -pl.figure(4) +pl.figure(5) pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.title('Sources and Target distributions') -- cgit v1.2.3 From e0ba31ce39a7d9e65e50ea970a574b3db54e4207 Mon Sep 17 00:00:00 2001 From: Tanguy Date: Fri, 17 Sep 2021 18:36:33 +0200 Subject: [MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein. * Correct some lines in SaGroW and PoGroW to follow pep8 guide. * Change nb_samples name. Use rdm state. Change symmetric check. * Change names of len(p) and len(q) in SaGroW and PoGroW. * Re-add some deleted lines in the comments of gromov.py Co-authored-by: Rémi Flamary --- README.md | 4 + examples/gromov/plot_gromov.py | 34 ++++ ot/gromov.py | 376 +++++++++++++++++++++++++++++++++++++++++ test/test_gromov.py | 88 +++++++++- 4 files changed, 496 insertions(+), 6 deletions(-) (limited to 'examples') diff --git a/README.md b/README.md index 6a2cf15..266d847 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples): * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20]. * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] @@ -198,6 +199,7 @@ The contributors to this library are * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) +* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein) * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -286,3 +288,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index deb2f86..5a362cf 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -104,3 +104,37 @@ pl.imshow(gw, cmap='jet') pl.title('Entropic Gromov Wasserstein') pl.show() + +############################################################################# +# +# Compute GW with a scalable stochastic method with any loss function +# ---------------------------------------------------------------------- + + +def loss(x, y): + return np.abs(x - y) + + +pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100, + log=True) + +sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100, + log=True) + +print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated'])) +print('Variance estimated: ' + str(plog['gw_dist_std'])) +print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated'])) +print('Variance estimated: ' + str(slog['gw_dist_std'])) + + +pl.figure(1, (10, 5)) + +pl.subplot(1, 2, 1) +pl.imshow(pgw.toarray(), cmap='jet') +pl.title('Pointwise Gromov Wasserstein') + +pl.subplot(1, 2, 2) +pl.imshow(sgw, cmap='jet') +pl.title('Sampled Gromov Wasserstein') + +pl.show() diff --git a/ot/gromov.py b/ot/gromov.py index 8f457e9..a27217a 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -16,6 +16,10 @@ import numpy as np from .bregman import sinkhorn from .utils import dist, UndefinedParameter from .optim import cg +from .lp import emd_1d, emd +from .utils import check_random_state + +from scipy.sparse import issparse def init_matrix(C1, C2, p, q, loss_fun='square_loss'): @@ -572,6 +576,378 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 return log['fgw_dist'] +def GW_distance_estimation(C1, C2, p, q, loss_fun, T, + nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): + r""" + Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) + with a fixed transport plan T. + + The function gives an unbiased approximation of the following equation: + + .. math:: + GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - L : Loss function to account for the misfit between the similarity matrices + - T : Matrix with marginal p and q + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + T : csr or ndarray, shape (ns, nt) + Transport plan matrix, either a sparse csr matrix or + nb_samples_p : int, optional + nb_samples_p is the number of samples (without replacement) along the first dimension of T. + nb_samples_q : int, optional + nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. + std : bool, optional + Standard deviation associated with the prediction of the gromov-wasserstein cost. + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + : float + Gromov-wasserstein cost + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + generator = check_random_state(random_state) + + len_p = len(p) + len_q = len(q) + + # It is always better to sample from the biggest distribution first. + if len_p < len_q: + p, q = q, p + len_p, len_q = len_q, len_p + C1, C2 = C2, C1 + T = T.T + + if nb_samples_p is None: + if issparse(T): + # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced + nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) + else: + nb_samples_p = len_p + else: + # The number of sample along the first dimension is without replacement. + nb_samples_p = min(nb_samples_p, len_p) + if nb_samples_q is None: + nb_samples_q = 1 + if std: + nb_samples_q = max(2, nb_samples_q) + + index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) + + index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + + for i in range(nb_samples_p): + if issparse(T): + T_indexi = T[index_i[i], :].toarray()[0] + T_indexj = T[index_j[i], :].toarray()[0] + else: + T_indexi = T[index_i[i], :] + T_indexj = T[index_j[i], :] + # For each of the row sampled, the column is sampled. + index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) + index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) + + for n in range(nb_samples_q): + list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + + if std: + std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 + return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) + else: + return np.mean(list_value_sample) + + +def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, + alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. + This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold_plan : float, optional + Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + len_p = len(p) + len_q = len(q) + + generator = check_random_state(random_state) + + index = np.zeros(2, dtype=int) + + # Initialize with default marginal + index[0] = generator.choice(len_p, size=1, p=p) + index[1] = generator.choice(len_q, size=1, p=q) + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + + best_gw_dist_estimated = np.inf + for cpt in range(max_iter): + index[0] = generator.choice(len_p, size=1, p=p) + T_index0 = T[index[0], :].toarray()[0] + index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) + + if alpha == 1: + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + else: + new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = (1 - alpha) * T + alpha * new_T + # To limit the number of non 0, the values bellow the threshold are set to 0. + T.data[T.data < threshold_plan] = 0 + T.eliminate_zeros() + + if cpt % 10 == 0 or cpt == (max_iter - 1): + gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False, random_state=generator) + + if gw_dist_estimated < best_gw_dist_estimated: + best_gw_dist_estimated = gw_dist_estimated + best_T = T.copy() + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T, + random_state=generator) + return best_T, log + return best_T + + +def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, + nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, + random_state=None): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. + This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples_grad : int + Number of samples to approximate the gradient + epsilon : float + Weight of the Kullback-Leiber regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + len_p = len(p) + len_q = len(q) + + generator = check_random_state(random_state) + + # The most natural way to define nb_sample is with a simple integer. + if isinstance(nb_samples_grad, int): + if nb_samples_grad > len_p: + # As the sampling along the first dimension is done without replacement, the rest is reported to the second + # dimension. + nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad + T = np.outer(p, q) + # continue_loop allows to stop the loop if there is several successive small modification of T. + continue_loop = 0 + + # The gradient of GW is more complex if the two matrices are not symmetric. + C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) + + for cpt in range(max_iter): + index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) + Lik = 0 + for i, index0_i in enumerate(index0): + index1 = generator.choice(len_q, + size=nb_samples_grad_q, + p=T[index0_i, :] / T[index0_i, :].sum(), + replace=False) + # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. + if (not C_are_symmetric) and generator.rand(1) > 0.5: + Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), + np.expand_dims(C2[:, index1], 0)), + axis=2) + else: + Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), + np.expand_dims(C2[index1, :], 1)), + axis=0) + + max_Lik = np.max(Lik) + if max_Lik == 0: + continue + # This division by the max is here to facilitate the choice of epsilon. + Lik /= max_Lik + + if epsilon > 0: + # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. + log_T = np.log(np.clip(T, np.exp(-200), 1)) + log_T[log_T == -200] = -np.inf + Lik = Lik - epsilon * log_T + + try: + new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) + except (RuntimeWarning, UserWarning): + print("Warning catched in Sinkhorn: Return last stable T") + break + else: + new_T = emd(a=p, b=q, M=Lik) + + change_T = ((T - new_T) ** 2).mean() + if change_T <= 10e-20: + continue_loop += 1 + if continue_loop > 100: # Number max of low modifications of T + T = new_T.copy() + break + else: + continue_loop = 0 + + if verbose and cpt % 10 == 0: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, change_T)) + T = new_T.copy() + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, random_state=generator) + return T, log + return T + + def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" diff --git a/test/test_gromov.py b/test/test_gromov.py index 56414a8..19d61b1 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -33,7 +33,7 @@ def test_gromov(): G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -54,7 +54,7 @@ def test_gromov(): np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -83,7 +83,7 @@ def test_entropic_gromov(): G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -96,13 +96,89 @@ def test_entropic_gromov(): np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov +def test_pointwise_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.0 + assert log['gw_dist_std'] == 0.0 + + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) + + assert log['gw_dist_estimated'] == 0.10342276348494964 + assert log['gw_dist_std'] == 0.0015952535464736394 + + +def test_sampled_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + G, log = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) + + # check constraints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.05679474884977278 + assert log['gw_dist_std'] == 0.0005986592106971995 + + def test_gromov_barycenter(): ns = 50 nt = 60 @@ -186,7 +262,7 @@ def test_fgw(): G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( @@ -203,7 +279,7 @@ def test_fgw(): np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( -- cgit v1.2.3 From 7af8c2147d61349f4d99ca33318a8a125e4569aa Mon Sep 17 00:00:00 2001 From: haoran010 <62598274+haoran010@users.noreply.github.com> Date: Mon, 25 Oct 2021 10:47:22 +0200 Subject: [MRG] Regularization path for l2 UOT (#274) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add reg path * debug examples and verify pep8 * pep8 and move the reg path examples in unbalanced folder Co-authored-by: haoran010 Co-authored-by: Rémi Flamary --- examples/unbalanced-partial/plot_regpath.py | 135 +++++ ot/__init__.py | 3 +- ot/regpath.py | 827 ++++++++++++++++++++++++++++ test/test_regpath.py | 64 +++ 4 files changed, 1028 insertions(+), 1 deletion(-) create mode 100644 examples/unbalanced-partial/plot_regpath.py create mode 100644 ot/regpath.py create mode 100644 test/test_regpath.py (limited to 'examples') diff --git a/examples/unbalanced-partial/plot_regpath.py b/examples/unbalanced-partial/plot_regpath.py new file mode 100644 index 0000000..4a51c2d --- /dev/null +++ b/examples/unbalanced-partial/plot_regpath.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +""" +================================================================ +Regularization path of l2-penalized unbalanced optimal transport +================================================================ +This example illustrate the regularization path for 2D unbalanced +optimal transport. We present here both the fully relaxed case +and the semi-relaxed case. + +[Chapel et al., 2021] Chapel, L., Flamary, R., Wu, H., Févotte, C., +and Gasso, G. (2021). Unbalanced optimal transport through non-negative +penalized linear regression. +""" + +# Author: Haoran Wu +# License: MIT License + + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +#%% parameters and data generation + +n = 50 # nb samples + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +np.random.seed(0) +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +# loss matrix +M = ot.dist(xs, xt) +M /= M.max() + +############################################################################## +# Plot data +# --------- + +#%% plot 2 distribution samples + +pl.figure(1) +pl.scatter(xs[:, 0], xs[:, 1], c='C0', label='Source') +pl.scatter(xt[:, 0], xt[:, 1], c='C1', label='Target') +pl.legend(loc=2) +pl.title('Source and target distributions') +pl.show() + +############################################################################## +# Compute semi-relaxed and fully relaxed regularization paths +# ----------- + +#%% +final_gamma = 1e-8 +t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma, + semi_relaxed=False) +t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma, + semi_relaxed=True) + + +############################################################################## +# Plot the regularization path +# ---------------- + +#%% fully relaxed l2-penalized UOT + +pl.figure(2) +selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3] +for p in range(4): + tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list, + t_list) + P = tp.reshape((n, n)) + pl.subplot(2, 2, p + 1) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2, + label='Re-weighted source', alpha=1) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2, + label='Re-weighted target', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'$\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]), + fontsize=11) + if p < 2: + pl.xticks(()) +pl.show() + + +############################################################################## +# Plot the semi-relaxed regularization path +# ------------------- + +#%% semi-relaxed l2-penalized UOT + +pl.figure(3) +selected_gamma = [10, 1, 1e-1, 1e-2] +for p in range(4): + tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2, + t_list2) + P = tp.reshape((n, n)) + pl.subplot(2, 2, p + 1) + if P.sum() > 0: + P = P / P.max() + for i in range(n): + for j in range(n): + if P[i, j] > 0: + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2', + alpha=P[i, j] * 0.3) + pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2) + pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=1, label='Target marginal') + pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * 2 * (1 + p), + label='Source marginal', alpha=1) + pl.plot([], [], color='C2', alpha=0.8, label='OT plan') + pl.title(r'Semi-relaxed $l_2$ UOT $\gamma$={}'.format(selected_gamma[p]), + fontsize=11) + if p < 2: + pl.xticks(()) +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 3b072c6..5bd4bab 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -34,6 +34,7 @@ from . import stochastic from . import unbalanced from . import partial from . import backend +from . import regpath # OT functions from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d @@ -54,4 +55,4 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', - 'smooth', 'stochastic', 'unbalanced', 'partial'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/regpath.py b/ot/regpath.py new file mode 100644 index 0000000..269937a --- /dev/null +++ b/ot/regpath.py @@ -0,0 +1,827 @@ +# -*- coding: utf-8 -*- +""" +Regularization path OT solvers +""" + +# Author: Haoran Wu +# License: MIT License + +import numpy as np +import scipy.sparse as sp + + +def recast_ot_as_lasso(a, b, C): + r"""This function recasts the l2-penalized UOT problem as a Lasso problem + + Recall the l2-penalized UOT problem defined in [Chapel et al., 2021] + .. math:: + UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + + \lambda \|T^T 1_n - b\|_2^2 + s.t. + T \geq 0 + where : + - C is the (dim_a, dim_b) metric cost matrix + - :math:`\lambda` is the l2-regularization coefficient + - a and b are source and target distributions + - T is the transport plan to optimize + + The problem above can be reformulated to a non-negative penalized + linear regression problem, particularly Lasso + .. math:: + UOT2 = \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + s.t. + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] + - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, + see [Chapel et al., 2021] for the design of H. The matrix product H t + computes both the source marginal and the target marginal. + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + Returns + ------- + H : np.ndarray (dim_a+dim_b, dim_a*dim_b) + Auxiliary matrix constituted by 0 and 1 + y : np.ndarray (ns + nt, ) + Concatenation of histogram a and histogram b + c : np.ndarray (ns * nt, ) + Flattened array of cost matrix + Examples + -------- + >>> import ot + >>> a = np.array([0.2, 0.3, 0.5]) + >>> b = np.array([0.1, 0.9]) + >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]]) + >>> H, y, c = ot.regpath.recast_ot_as_lasso(a, b, C) + >>> H.toarray() + array([[1., 1., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0.], + [0., 0., 0., 0., 1., 1.], + [1., 0., 1., 0., 1., 0.], + [0., 1., 0., 1., 0., 1.]]) + >>> y + array([0.2, 0.3, 0.5, 0.1, 0.9]) + >>> c + array([16., 25., 28., 16., 40., 36.]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + dim_a = np.shape(a)[0] + dim_b = np.shape(b)[0] + y = np.concatenate((a, b)) + c = C.flatten() + jHa = np.arange(dim_a * dim_b) + iHa = np.repeat(np.arange(dim_a), dim_b) + jHb = np.arange(dim_a * dim_b) + iHb = np.tile(np.arange(dim_b), dim_a) + dim_a + j = np.concatenate((jHa, jHb)) + i = np.concatenate((iHa, iHb)) + H = sp.csc_matrix((np.ones(dim_a * dim_b * 2), (i, j)), + shape=(dim_a + dim_b, dim_a * dim_b)) + return H, y, c + + +def recast_semi_relaxed_as_lasso(a, b, C): + r"""This function recasts the semi-relaxed l2-UOT problem as Lasso problem + + .. math:: + semi-relaxed UOT = \min_T + \lambda \|T 1_m - a\|_2^2 + s.t. + T^T 1_n = b + t \geq 0 + where : + - C is the (dim_a, dim_b) metric cost matrix + - :math:`\lambda` is the l2-regularization coefficient + - a and b are source and target distributions + - T is the transport plan to optimize + + The problem above can be reformulated as follows + .. math:: + semi-relaxed UOT2 = \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + s.t. + H_c t = b + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - H_r is a (dim_a, dim_a * dim_b) metric matrix, + which computes the sum along the rows of transport plan T + - H_c is a (dim_b, dim_a * dim_b) metric matrix, + which computes the sum along the columns of transport plan T + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + Returns + ------- + Hr : np.ndarray (dim_a, dim_a * dim_b) + Auxiliary matrix constituted by 0 and 1, which computes + the sum along the rows of transport plan T + Hc : np.ndarray (dim_b, dim_a * dim_b) + Auxiliary matrix constituted by 0 and 1, which computes + the sum along the columns of transport plan T + c : np.ndarray (ns * nt, ) + Flattened array of cost matrix + Examples + -------- + >>> import ot + >>> a = np.array([0.2, 0.3, 0.5]) + >>> b = np.array([0.1, 0.9]) + >>> C = np.array([[16., 25.], [28., 16.], [40., 36.]]) + >>> Hr,Hc,c = ot.regpath.recast_semi_relaxed_as_lasso(a, b, C) + >>> Hr.toarray() + array([[1., 1., 0., 0., 0., 0.], + [0., 0., 1., 1., 0., 0.], + [0., 0., 0., 0., 1., 1.]]) + >>> Hc.toarray() + array([[1., 0., 1., 0., 1., 0.], + [0., 1., 0., 1., 0., 1.]]) + >>> c + array([16., 25., 28., 16., 40., 36.]) + """ + + dim_a = np.shape(a)[0] + dim_b = np.shape(b)[0] + + c = C.flatten() + jHr = np.arange(dim_a * dim_b) + iHr = np.repeat(np.arange(dim_a), dim_b) + jHc = np.arange(dim_a * dim_b) + iHc = np.tile(np.arange(dim_b), dim_a) + + Hr = sp.csc_matrix((np.ones(dim_a * dim_b), (iHr, jHr)), + shape=(dim_a, dim_a * dim_b)) + Hc = sp.csc_matrix((np.ones(dim_a * dim_b), (iHc, jHc)), + shape=(dim_b, dim_a * dim_b)) + + return Hr, Hc, c + + +def ot_next_gamma(phi, delta, HtH, Hty, c, active_index, current_gamma): + r""" This function computes the next value of gamma if a variable + will be added in next iteration of the regularization path + + We look for the largest value of gamma such that + the gradient of an inactive variable vanishes + .. math:: + \max_{i \in \bar{A}} \frac{h_i^T(H_A \phi - y)}{h_i^T H_A \delta - c_i} + where : + - A is the current active set + - h_i is the ith column of auxiliary matrix H + - H_A is the sub-matrix constructed by the columns of H + whose indices belong to the active set A + - c_i is the ith element of cost vector c + - y is the concatenation of source and target distribution + - :math:`\phi` is the intercept of the solutions in current iteration + - :math:`\delta` is the slope of the solutions in current iteration + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + HtH : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H^T H + Hty : np.ndarray (dim_a + dim_b, ) + Matrix product of H^T y + c: np.ndarray (dim_a * dim_b, ) + Flattened array of cost matrix C + active_index : list + Indices of active variables + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_gamma : float + Value of gamma if a variable is added to active set in next iteration + next_active_index : int + Index of variable to be activated + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + M = (HtH[:, active_index].dot(phi) - Hty) / \ + (HtH[:, active_index].dot(delta) - c + 1e-16) + M[active_index] = 0 + M[M > (current_gamma - 1e-10 * current_gamma)] = 0 + return np.max(M), np.argmax(M) + + +def semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, HrHr, Hc, Hra, + c, active_index, current_gamma): + r""" This function computes the next value of gamma when a variable is + active in the regularization path of semi-relaxed UOT. + + By taking the Lagrangian form of the problem, we obtain a similar update + as the two-sided relaxed UOT + .. math:: + \max_{i \in \bar{A}} \frac{h_{r i}^T(H_{r A} \phi - a) + h_{c i}^T + \phi_u}{h_{r i}^T H_{r A} \delta + h_{c i} \delta_u - c_i} + where : + - A is the current active set + - h_{r i} is the ith column of the matrix H_r + - h_{c i} is the ith column of the matrix H_c + - H_{r A} is the sub-matrix constructed by the columns of H_r + whose indices belong to the active set A + - c_i is the ith element of cost vector c + - y is the concatenation of source and target distribution + - :math:`\phi` is the intercept of the solutions in current iteration + - :math:`\delta` is the slope of the solutions in current iteration + - :math:`\phi_u` is the intercept of Lagrange parameter in current + iteration + - :math:`\delta_u` is the slope of Lagrange parameter in current iteration + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + phi_u : np.ndarray (dim_b, ) + Intercept of the Lagrange parameter in current iteration (also linear) + delta_u : np.ndarray (dim_b, ) + Slope of the Lagrange parameter in current iteration (also linear) + HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H_r^T H_r + Hc : np.ndarray (dim_b, dim_a * dim_b) + Matrix that computes the sum along the columns of transport plan T + Hra : np.ndarray (dim_a * dim_b, ) + Matrix product of H_r^T a + c: np.ndarray (dim_a * dim_b, ) + Flattened array of cost matrix C + active_index : list + Indices of active variables + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_gamma : float + Value of gamma if a variable is added to active set in next iteration + next_active_index : int + Index of variable to be activated + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + M = (HrHr[:, active_index].dot(phi) - Hra + Hc.T.dot(phi_u)) / \ + (HrHr[:, active_index].dot(delta) - c + Hc.T.dot(delta_u) + 1e-16) + M[active_index] = 0 + M[M > (current_gamma - 1e-10 * current_gamma)] = 0 + return np.max(M), np.argmax(M) + + +def compute_next_removal(phi, delta, current_gamma): + r""" This function computes the next value of gamma if a variable + is removed in next iteration of regularization path + + We look for the largest value of gamma such that + an element of current solution vanishes + .. math:: + \max_{j \in A} \frac{\phi_j}{\delta_j} + where : + - A is the current active set + - phi_j is the jth element of the intercept of current solution + - delta_j is the jth elemnt of the slope of current solution + Parameters + ---------- + phi : np.ndarray (|A|, ) + Intercept of the solutions in current iteration (t is piecewise linear) + delta : np.ndarray (|A|, ) + Slope of the solutions in current iteration (t is piecewise linear) + current_gamma : float + Value of regularization coefficient at the start of current iteration + Returns + ------- + next_removal_gamma : float + Value of gamma if a variable is removed in next iteration + next_removal_index : int + Index of the variable to remove in next iteration + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + r_candidate = phi / (delta - 1e-16) + r_candidate[r_candidate >= (1 - 1e-8) * current_gamma] = 0 + return np.max(r_candidate), np.argmax(r_candidate) + + +def complement_schur(M_current, b, d, id_pop): + r""" This function computes the inverse of matrix in regularization path + using Schur complement + + Two cases may arise: Firstly one variable is added to the active set + .. math:: + M_{k+1}^{-1} = + \begin{bmatrix} + M_{k}^{-1} + s^{-1} M_{k}^{-1} b b^T M_{k}^{-1} & -s^{-1} \\ + - s^{-1} b^T M_{k}^{-1} & s^{-1} + \end{bmatrix} + where : + - :math:`M_k^{-1}` is the inverse of matrix in previous iteration and + :math:`M_k` is the upper left block matrix in Schur formulation + - b is the upper right block matrix in Schur formulation. In our case, + b is reduced to a column vector and b^T is the lower left block matrix + - s is the Schur complement, given by + :math:`s = d - b^T M_{k}^{-1} b` in our case + + Secondly, one variable is removed from the active set + .. math:: + M_{k+1}^{-1} = M^{-1}_{A_k \backslash q} - + \frac{r_{-q,q} r^{T}_{-q,q}}{r_{q,q}} + where : + - q is the index of column and row to delete + - :math:`M^{-1}_{A_k \backslash q}` is the previous inverse matrix + without qth column and qth row + - r_{-q,q} is the qth column of :math:`M^{-1}_{k}` without the qth element + - r_{q, q} is the element of qth column and qth row in :math:`M^{-1}_{k}` + Parameters + ---------- + M_current : np.ndarray (|A|-1, |A|-1) + Inverse matrix in previous iteration + b : np.ndarray (|A|-1, ) + Upper right matrix in Schur complement, a column vector in our case + d : float + Lower right matrix in Schur complement, a scalar in our case + id_pop + Index of the variable to be removed, equal to -1 + if none of the variables is deleted in current iteration + Returns + ------- + M : np.ndarray (|A|, |A|) + Inverse matrix needed in current iteration + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + if b is None: + b = M_current[id_pop, :] + b = np.delete(b, id_pop) + M_del = np.delete(M_current, id_pop, 0) + a = M_del[:, id_pop] + M_del = np.delete(M_del, id_pop, 1) + M = M_del - np.outer(a, b) / M_current[id_pop, id_pop] + else: + n = b.shape[0] + 1 + if np.shape(b)[0] == 0: + M = np.array([[0.5]]) + else: + X = M_current.dot(b) + s = d - b.T.dot(X) + M = np.zeros((n, n)) + M[:-1, :-1] = M_current + X.dot(X.T) / s + X_ravel = X.ravel() + M[-1, :-1] = -X_ravel / s + M[:-1, -1] = -X_ravel / s + M[-1, -1] = 1 / s + return M + + +def construct_augmented_H(active_index, m, Hc, HrHr): + r""" This function construct an augmented matrix for the first iteration of + semi-relaxed regularization path + + .. math:: + Augmented_H = + \begin{bmatrix} + 0 & H_{c A} \\ + H_{c A}^T & H_{r A}^T H_{r A} + \end{bmatrix} + where : + - H_{r A} is the sub-matrix constructed by the columns of H_r + whose indices belong to the active set A + - H_{c A} is the sub-matrix constructed by the columns of H_c + whose indices belong to the active set A + Parameters + ---------- + active_index : list + Indices of active variables + m : int + Length of the target distribution + Hc : np.ndarray (dim_b, dim_a * dim_b) + Matrix that computes the sum along the columns of transport plan T + HrHr : np.ndarray (dim_a * dim_b, dim_a * dim_b) + Matrix product of H_r^T H_r + Returns + ------- + H_augmented : np.ndarray (dim_b + |A|, dim_b + |A|) + Augmented matrix for the first iteration of the semi-relaxed + regularization path + """ + Hc_sub = Hc[:, active_index].toarray() + HrHr_sub = HrHr[:, active_index] + HrHr_sub = HrHr_sub[active_index, :].toarray() + H_augmented = np.block([[np.zeros((m, m)), Hc_sub], [Hc_sub.T, HrHr_sub]]) + return H_augmented + + +def fully_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + itmax=50000): + r"""This function gives the regularization path of l2-penalized UOT problem + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: + \min_t \gamma c^T t + 0.5 * \|H t - y\|_2^2 + s.t. + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - y is the concatenation of vectors a and b, defined as y^T = [a^T b^T] + - H is a (dim_a + dim_b, dim_a * dim_b) metric matrix, + see [Chapel et al., 2021] for the design of H. The matrix product Ht + computes both the source marginal and the target marginal. + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float + l2-regularization coefficient + itmax: int + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, _, _ = ot.regpath.fully_relaxed_path(a, b, C, 1e-4) + >>> t + array([1.99958333e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05, + 4.99938889e-01, 0.00000000e+00, 0.00000000e+00, 3.88888889e-05, + 2.99958333e-01]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + n = np.shape(a)[0] + m = np.shape(b)[0] + H, y, c = recast_ot_as_lasso(a, b, C) + HtH = H.T.dot(H) + Hty = H.T.dot(y) + n_iter = 1 + + # initialization + M0 = Hty / c + gamma_list = [np.max(M0)] + active_index = [np.argmax(M0)] + t_list = [np.zeros((n * m,))] + H_inv = np.array([[]]) + add_col = np.array([]) + id_pop = -1 + + while n_iter < itmax and gamma_list[-1] > reg: + H_inv = complement_schur(H_inv, add_col, 2., id_pop) + current_gamma = gamma_list[-1] + + # compute the intercept and slope of solutions in current iteration + # t = phi - gamma * delta + phi = H_inv.dot(Hty[active_index]) + delta = H_inv.dot(c[active_index]) + gamma, ik = ot_next_gamma(phi, delta, HtH, Hty, c, active_index, + current_gamma) + + # compute the next lambda when removing a point from the active set + alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) + + # if the positivity constraint is violated, we remove id_pop + # from active set, otherwise we add ik to active set + if alt_gamma > gamma: + gamma = alt_gamma + else: + id_pop = -1 + + # compute the solution of current segment + tA = phi - gamma * delta + sol = np.zeros((n * m, )) + sol[active_index] = tA + + if id_pop != -1: + active_index.pop(id_pop) + add_col = None + else: + active_index.append(ik) + add_col = HtH[active_index[:-1], ik].toarray() + + gamma_list.append(gamma) + t_list.append(sol) + n_iter += 1 + + if itmax <= n_iter: + print('maximum iteration has been reached !') + + # correct the last solution and gamma + if len(t_list) > 1: + t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * + (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_list[-1] = t_final + gamma_list[-1] = reg + else: + gamma_list[-1] = reg + print('Regularization path does not exist !') + + return t_list[-1], t_list, gamma_list + + +def semi_relaxed_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + itmax=50000): + r"""This function gives the regularization path of semi-relaxed + l2-UOT problem + + The problem to optimize is the Lasso reformulation of the l2-penalized UOT: + .. math:: + \min_t \gamma c^T t + 0.5 * \|H_r t - a\|_2^2 + s.t. + H_c t = b + t \geq 0 + where : + - c is a (dim_a * dim_b, ) metric cost vector (flattened version of C) + - :math:`\gamma = 1/\lambda` is the l2-regularization coefficient + - H_r is a (dim_a, dim_a * dim_b) metric matrix, + which computes the sum along the rows of transport plan T + - H_c is a (dim_b, dim_a * dim_b) metric matrix, + which computes the sum along the columns of transport plan T + - t is a (dim_a * dim_b, ) metric vector (flattened version of T) + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float (optional) + l2-regularization coefficient + itmax: int (optional) + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, _, _ = ot.regpath.semi_relaxed_path(a, b, C, 1e-4) + >>> t + array([1.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05, + 4.99980556e-01, 0.00000000e+00, 0.00000000e+00, 1.94444444e-05, + 3.00000000e-01]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + n = np.shape(a)[0] + m = np.shape(b)[0] + Hr, Hc, c = recast_semi_relaxed_as_lasso(a, b, C) + Hra = Hr.T.dot(a) + HrHr = Hr.T.dot(Hr) + n_iter = 1 + active_index = [] + + # initialization + for j in range(np.shape(C)[1]): + i = np.argmin(C[:, j]) + active_index.append(i * m + j) + gamma_list = [] + t_list = [] + current_gamma = np.Inf + augmented_H0 = construct_augmented_H(active_index, m, Hc, HrHr) + add_col = np.array([]) + id_pop = -1 + + while n_iter < itmax and current_gamma > reg: + if n_iter == 1: + H_inv = np.linalg.inv(augmented_H0) + else: + H_inv = complement_schur(H_inv, add_col, 1., id_pop + m) + # compute the intercept and slope of solutions in current iteration + augmented_phi = H_inv.dot(np.concatenate((b, Hra[active_index]))) + augmented_delta = H_inv[:, m:].dot(c[active_index]) + phi = augmented_phi[m:] + delta = augmented_delta[m:] + phi_u = augmented_phi[0:m] + delta_u = augmented_delta[0:m] + gamma, ik = semi_relaxed_next_gamma(phi, delta, phi_u, delta_u, + HrHr, Hc, Hra, c, active_index, + current_gamma) + + # compute the next lambda when removing a point from the active set + alt_gamma, id_pop = compute_next_removal(phi, delta, current_gamma) + + # if the positivity constraint is violated, we remove id_pop + # from active set, otherwise we add ik to active set + if alt_gamma > gamma: + gamma = alt_gamma + else: + id_pop = -1 + + # compute the solution of current segment + tA = phi - gamma * delta + sol = np.zeros((n * m, )) + sol[active_index] = tA + if id_pop != -1: + active_index.pop(id_pop) + add_col = None + else: + active_index.append(ik) + add_col = np.concatenate((Hc.toarray()[:, ik], + HrHr.toarray()[active_index[:-1], ik])) + add_col = add_col[:, np.newaxis] + + gamma_list.append(gamma) + t_list.append(sol) + current_gamma = gamma + n_iter += 1 + + if itmax <= n_iter: + print('maximum iteration has been reached !') + + # correct the last solution and gamma + if len(t_list) > 1: + t_final = (t_list[-2] + (t_list[-1] - t_list[-2]) * + (reg - gamma_list[-2]) / (gamma_list[-1] - gamma_list[-2])) + t_list[-1] = t_final + gamma_list[-1] = reg + else: + gamma_list[-1] = reg + print('Regularization path does not exist !') + + return t_list[-1], t_list, gamma_list + + +def regularization_path(a: np.array, b: np.array, C: np.array, reg=1e-4, + semi_relaxed=False, itmax=50000): + r"""This function combines both the semi-relaxed and the fully-relaxed + regularization paths of l2-UOT problem + + Parameters + ---------- + a : np.ndarray (dim_a,) + Histogram of dimension dim_a + b : np.ndarray (dim_b,) + Histogram of dimension dim_b + C : np.ndarray, shape (dim_a, dim_b) + Cost matrix + reg: float (optional) + l2-regularization coefficient + semi_relaxed : bool (optional) + Give the semi-relaxed path if true + itmax: int (optional) + Maximum number of iteration + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Flattened vector of optimal transport matrix + t_list : list + List of solutions in regularization path + gamma_list : list + List of regularization coefficient in regularization path + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + if semi_relaxed: + t, t_list, gamma_list = semi_relaxed_path(a, b, C, reg=reg, + itmax=itmax) + else: + t, t_list, gamma_list = fully_relaxed_path(a, b, C, reg=reg, + itmax=itmax) + return t, t_list, gamma_list + + +def compute_transport_plan(gamma, gamma_list, Pi_list): + r""" Given the regularization path, this function computes the transport + plan for any value of gamma by the piecewise linearity of the path + + .. math:: + t(\gamma) = \phi(\gamma) - \gamma \delta(\gamma) + where : + - :math:`\gamma` is the regularization coefficient + - :math:`\phi(\gamma)` is the corresponding intercept + - :math:`\delta(\gamma)` is the corresponding slope + - t is a (dim_a * dim_b, ) vector (flattened version of transport matrix) + Parameters + ---------- + gamma : float + Regularization coefficient + gamma_list : list + List of regularization coefficients in regularization path + Pi_list : list + List of solutions in regularization path + Returns + ------- + t : np.ndarray (dim_a*dim_b, ) + Transport vector corresponding to the given value of gamma + Examples + -------- + >>> import ot + >>> import numpy as np + >>> n = 3 + >>> xs = np.array([1., 2., 3.]).reshape((n, 1)) + >>> xt = np.array([5., 6., 7.]).reshape((n, 1)) + >>> C = ot.dist(xs, xt) + >>> C /= C.max() + >>> a = np.array([0.2, 0.5, 0.3]) + >>> b = np.array([0.2, 0.5, 0.3]) + >>> t, pi_list, g_list = ot.regpath.regularization_path(a, b, C, reg=1e-4) + >>> gamma = 1 + >>> t2 = ot.regpath.compute_transport_plan(gamma, g_list, pi_list) + >>> t2 + array([0. , 0. , 0. , 0.19722222, 0.05555556, + 0. , 0. , 0.24722222, 0. ]) + + References + ---------- + [Chapel et al., 2021]: + Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. + """ + + if gamma >= gamma_list[0]: + Pi = Pi_list[0] + elif gamma <= gamma_list[-1]: + Pi = Pi_list[-1] + else: + idx = np.where(gamma <= np.array(gamma_list))[0][-1] + gamma_k0 = gamma_list[idx] + gamma_k1 = gamma_list[idx + 1] + pi_k0 = Pi_list[idx] + pi_k1 = Pi_list[idx + 1] + Pi = pi_k0 + (pi_k1 - pi_k0) * (gamma - gamma_k0) \ + / (gamma_k1 - gamma_k0) + return Pi diff --git a/test/test_regpath.py b/test/test_regpath.py new file mode 100644 index 0000000..967c27b --- /dev/null +++ b/test/test_regpath.py @@ -0,0 +1,64 @@ +"""Tests for module regularization path""" + +# Author: Haoran Wu +# +# License: MIT License + +import numpy as np +import ot + + +def test_fully_relaxed_path(): + + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + np.random.seed(0) + xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov) + + # source and target distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, + semi_relaxed=False) + + G = t.reshape((n_source, n_target)) + np.testing.assert_allclose(a, G.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G.sum(0), atol=1e-05) + + +def test_semi_relaxed_path(): + + n_source = 50 # nb source samples (gaussian) + n_target = 40 # nb target samples (gaussian) + + mu = np.array([0, 0]) + cov = np.array([[1, 0], [0, 2]]) + + np.random.seed(0) + xs = ot.datasets.make_2D_samples_gauss(n_source, mu, cov) + xt = ot.datasets.make_2D_samples_gauss(n_target, mu, cov) + + # source and target distributions + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + t, _, _ = ot.regpath.regularization_path(a, b, M, reg=1e-8, + semi_relaxed=True) + + G = t.reshape((n_source, n_target)) + np.testing.assert_allclose(a, G.sum(1), atol=1e-05) + np.testing.assert_allclose(b, G.sum(0), atol=1e-10) -- cgit v1.2.3 From 6775a527f9d3c801f8cdd805d8f205b6a75551b9 Mon Sep 17 00:00:00 2001 From: Nicolas Courty Date: Tue, 2 Nov 2021 14:19:57 +0100 Subject: [MRG] Sliced and 1D Wasserstein distances : backend versions (#256) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update test/test_utils.py Co-authored-by: Alexandre Gramfort * Update ot/utils.py Co-authored-by: Alexandre Gramfort * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- README.md | 11 +- docs/source/readme.rst | 51 ++- .../backends/plot_sliced_wass_grad_flow_pytorch.py | 185 +++++++++++ examples/backends/plot_wass1d_torch.py | 152 +++++++++ examples/sliced-wasserstein/plot_variance.py | 2 +- ot/__init__.py | 5 +- ot/backend.py | 98 ++++++ ot/lp/__init__.py | 367 ++------------------- ot/lp/solver_1d.py | 367 +++++++++++++++++++++ ot/sliced.py | 181 ++++++++-- test/test_1d_solver.py | 85 +++++ test/test_backend.py | 36 ++ test/test_ot.py | 57 +--- test/test_sliced.py | 90 ++++- test/test_utils.py | 2 +- 15 files changed, 1244 insertions(+), 445 deletions(-) create mode 100644 examples/backends/plot_sliced_wass_grad_flow_pytorch.py create mode 100644 examples/backends/plot_wass1d_torch.py create mode 100644 ot/lp/solver_1d.py create mode 100644 test/test_1d_solver.py (limited to 'examples') diff --git a/README.md b/README.md index f0e5227..cfb9744 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). -* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/) arrays. POT provides the following Machine Learning related solvers: @@ -285,4 +285,11 @@ You can also post bug reports and feature requests in Github issues. Make sure t [33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021 -[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). [Interpolating between optimal transport and MMD using Sinkhorn divergences](http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf). In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on +Machine Learning (pp. 4104-4113). PMLR. diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 82d3e6c..ee32e2b 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -24,7 +24,7 @@ POT provides the following generic OT solvers (links to examples): for regularized OT [7]. - Entropic regularization OT solver with `Sinkhorn Knopp Algorithm `__ - [2] , stabilized version [9] [10], greedy Sinkhorn [22] and + [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and `Screening Sinkhorn [26] `__. - Bregman projections for `Wasserstein @@ -54,6 +54,9 @@ POT provides the following generic OT solvers (links to examples): solver `__ for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) +- `Stochastic solver of Gromov + Wasserstein `__ + for large-scale problem with any loss functions [33] - Non regularized `free support Wasserstein barycenters `__ [20]. @@ -137,19 +140,12 @@ following Python modules: - Numpy (>=1.16) - Scipy (>=1.0) -- Cython (>=0.23) (build only, not necessary when installing wheels - from pip or conda) +- Cython (>=0.23) (build only, not necessary when installing from pip + or conda) Pip installation ^^^^^^^^^^^^^^^^ -Note that due to a limitation of pip, ``cython`` and ``numpy`` need to -be installed prior to installing POT. This can be done easily with - -.. code:: console - - pip install numpy cython - You can install the toolbox through PyPI with: .. code:: console @@ -183,7 +179,8 @@ without errors: import ot -Note that for easier access the module is name ot instead of pot. +Note that for easier access the module is named ``ot`` instead of +``pot``. Dependencies ~~~~~~~~~~~~ @@ -222,7 +219,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix Wd = ot.emd2(a, b, M) # exact linear program Wd_reg = ot.sinkhorn2(a, b, M, reg) # entropic regularized OT @@ -232,7 +229,7 @@ Short examples .. code:: python - # a and b are 1D histograms (sum to 1 and positive) + # a,b are 1D histograms (sum to 1 and positive) # M is the ground cost matrix T = ot.emd(a, b, M) # exact linear program T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT @@ -287,6 +284,10 @@ The contributors to this library are - `Ievgen Redko `__ (Laplacian DA, JCPOT) - `Adrien Corenflos `__ (Sliced Wasserstein Distance) +- `Tanguy Kerdoncuff `__ (Sampled Gromov + Wasserstein) +- `Minhui Huang `__ (Projection Robust + Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various @@ -476,6 +477,30 @@ of measures `__, Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +[32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate +Descent Method for Computing the Projection Robust Wasserstein +Distance `__, +Proceedings of the 38th International Conference on Machine Learning +(ICML). + +[33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov +Wasserstein `__, +Machine Learning Journal (MJL), 2021 + +[34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., +& Peyré, G. (2019, April). `Interpolating between optimal transport and +MMD using Sinkhorn +divergences `__. +In The 22nd International Conference on Artificial Intelligence and +Statistics (pp. 2681-2690). PMLR. + +[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., +Koyejo, S., ... & Schwing, A. G. (2019). `Max-sliced wasserstein +distance and its use for +gans `__. +In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern +Recognition (pp. 10648-10656). + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT .. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg diff --git a/examples/backends/plot_sliced_wass_grad_flow_pytorch.py b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py new file mode 100644 index 0000000..05b9952 --- /dev/null +++ b/examples/backends/plot_sliced_wass_grad_flow_pytorch.py @@ -0,0 +1,185 @@ +r""" +================================= +Sliced Wasserstein barycenter and gradient flow with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the sliced Wasserstein +loss between two empirical distributions [31]. + +In the first example one we perform a +gradient flow on the support of a distribution that minimize the sliced +Wassersein distance as poposed in [36]. + +In the second exemple we optimize with a gradient descent the sliced +Wasserstein barycenter between two distributions as in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of +measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions. In International Conference on +Machine Learning (pp. 4104-4113). PMLR. + + +""" +# Author: Rémi Flamary +# +# License: MIT License + + +# %% +# Loading the data + + +import numpy as np +import matplotlib.pylab as pl +import torch +import ot +import matplotlib.animation as animation + +I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2] +I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2] + +sz = I2.shape[0] +XX, YY = np.meshgrid(np.arange(sz), np.arange(sz)) + +x1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0 +x2 = np.stack((XX[I2 == 0] + 60, -YY[I2 == 0] + 32), 1) * 1.0 +x3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0 + +pl.figure(1, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) + +# %% +# Sliced Wasserstein gradient flow with Pytorch +# --------------------------------------------- + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x1_torch = torch.tensor(x1).to(device=device).requires_grad_(True) +x2_torch = torch.tensor(x2).to(device=device) + + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, x1.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +for i in range(nb_iter_max): + + loss = ot.sliced_wasserstein_distance(x1_torch, x2_torch, n_projections=20, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = x1_torch.grad + x1_torch -= grad * lr / (1 + i / 5e1) # step + x1_torch.grad.zero_() + x_all[i, :, :] = x1_torch.clone().detach().cpu().numpy() + +xb = x1_torch.clone().detach().cpu().numpy() + +pl.figure(2, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0], xb[:, 1], alpha=0.5, label='$\mu^{(100)}$') +pl.title('Sliced Wasserstein gradient flow') +pl.legend() +ax = pl.axis() + +# %% +# Animate trajectories of the gradient flow along iteration +# ------------------------------------------------------- + +pl.figure(3, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0], x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein gradient flow Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) + +# %% +# Compute the Sliced Wasserstein Barycenter +# +x1_torch = torch.tensor(x1).to(device=device) +x3_torch = torch.tensor(x3).to(device=device) +xbinit = np.random.randn(500, 2) * 10 + 16 +xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True) + +lr = 1e3 +nb_iter_max = 100 + +x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2)) + +loss_iter = [] + +# generator for random permutations +gen = torch.Generator() +gen.manual_seed(42) + +alpha = 0.5 + +for i in range(nb_iter_max): + + loss = alpha * ot.sliced_wasserstein_distance(xbary_torch, x3_torch, n_projections=50, seed=gen) \ + + (1 - alpha) * ot.sliced_wasserstein_distance(xbary_torch, x1_torch, n_projections=50, seed=gen) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = xbary_torch.grad + xbary_torch -= grad * lr # / (1 + i / 5e1) # step + xbary_torch.grad.zero_() + x_all[i, :, :] = xbary_torch.clone().detach().cpu().numpy() + +xb = xbary_torch.clone().detach().cpu().numpy() + +pl.figure(4, (8, 4)) +pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu$') +pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') +pl.scatter(xb[:, 0] + 30, xb[:, 1], alpha=0.5, label='Barycenter') +pl.title('Sliced Wasserstein barycenter') +pl.legend() +ax = pl.axis() + + +# %% +# Animate trajectories of the barycenter along gradient descent +# ------------------------------------------------------- + +pl.figure(5, (8, 4)) + + +def _update_plot(i): + pl.clf() + pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5, label='$\mu^{(0)}$') + pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5, label=r'$\nu$') + pl.scatter(x_all[i, :, 0] + 30, x_all[i, :, 1], alpha=0.5, label='$\mu^{(100)}$') + pl.title('Sliced Wasserstein barycenter Iter. {}'.format(i)) + pl.axis(ax) + return 1 + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, nb_iter_max, interval=100, repeat_delay=2000) diff --git a/examples/backends/plot_wass1d_torch.py b/examples/backends/plot_wass1d_torch.py new file mode 100644 index 0000000..0abdd6d --- /dev/null +++ b/examples/backends/plot_wass1d_torch.py @@ -0,0 +1,152 @@ +r""" +================================= +Wasserstein 1D with PyTorch +================================= + +In this small example, we consider the following minization problem: + +.. math:: + \mu^* = \min_\mu W(\mu,\nu) + +where :math:`\nu` is a reference 1D measure. The problem is handled +by a projected gradient descent method, where the gradient is computed +by pyTorch automatic differentiation. The projection on the simplex +ensures that the iterate will remain on the probability simplex. + +This example illustrates both `wasserstein_1d` function and backend use within +the POT framework. +""" +# Author: Nicolas Courty +# Rémi Flamary +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import matplotlib as mpl +import torch + +from ot.lp import wasserstein_1d +from ot.datasets import make_1D_gauss as gauss +from ot.utils import proj_simplex + +red = np.array(mpl.colors.to_rgb('red')) +blue = np.array(mpl.colors.to_rgb('blue')) + + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# enforce sum to one on the support +a = a / a.sum() +b = b / b.sum() + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device).requires_grad_(True) +b_torch = torch.tensor(b).to(device=device) + +lr = 1e-6 +nb_iter_max = 800 + +loss_iter = [] + +pl.figure(1, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a_torch.grad + a_torch -= a_torch.grad * lr # step + a_torch.grad.zero_() + a_torch.data = proj_simplex(a_torch) # projection onto the simplex + + # plot one curve every 10 iterations + if i % 10 == 0: + mix = float(i) / nb_iter_max + pl.plot(x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red) + +pl.legend() +pl.title('Distribution along the iterations of the projected gradient descent') +pl.show() + +pl.figure(2) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() + +# %% +# Wasserstein barycenter +# --------- +# In this example, we consider the following Wasserstein barycenter problem +# $$ \\eta^* = \\min_\\eta\;\;\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$ +# where :math:`\\mu` and :math:`\\nu` are reference 1D measures, and :math:`t` +# is a parameter :math:`\in [0,1]`. The problem is handled by a project gradient +# descent method, where the gradient is computed by pyTorch automatic differentiation. +# The projection on the simplex ensures that the iterate will remain on the +# probability simplex. +# +# This example illustrates both `wasserstein_1d` function and backend use within the +# POT framework. + + +device = "cuda" if torch.cuda.is_available() else "cpu" + +# use pyTorch for our data +x_torch = torch.tensor(x).to(device=device) +a_torch = torch.tensor(a).to(device=device) +b_torch = torch.tensor(b).to(device=device) +bary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True) + + +lr = 1e-6 +nb_iter_max = 2000 + +loss_iter = [] + +# instant of the interpolation +t = 0.5 + +for i in range(nb_iter_max): + # Compute the Wasserstein 1D with torch backend + loss = (1 - t) * wasserstein_1d(x_torch, x_torch, a_torch.detach(), bary_torch, p=2) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2) + # record the corresponding loss value + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = bary_torch.grad + bary_torch -= bary_torch.grad * lr # step + bary_torch.grad.zero_() + bary_torch.data = proj_simplex(bary_torch) # projection onto the simplex + +pl.figure(3, figsize=(8, 4)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.plot(x, bary_torch.clone().detach().cpu().numpy(), c='green', label='W barycenter') +pl.legend() +pl.title('Wasserstein barycenter computed by gradient descent') +pl.show() + +pl.figure(4) +pl.plot(range(nb_iter_max), loss_iter, lw=3) +pl.title('Evolution of the loss along iterations', fontsize=16) +pl.show() diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 27df656..7d73907 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -63,7 +63,7 @@ res = np.empty((n_seed, 25)) # %% Compute statistics for seed in range(n_seed): for i, n_projections in enumerate(n_projections_arr): - res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed=seed) res_mean = np.mean(res, axis=0) res_std = np.std(res, axis=0) diff --git a/ot/__init__.py b/ot/__init__.py index 5bd4bab..f20332c 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -42,7 +42,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance +from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -51,8 +51,9 @@ __version__ = "0.8.0dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', + 'emd2_1d', 'wasserstein_1d', 'backend', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'max_sliced_wasserstein_distance', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/backend.py b/ot/backend.py index 358297c..d3df44c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -103,6 +103,8 @@ class Backend(): __name__ = None __type__ = None + rng_ = None + def __str__(self): return self.__name__ @@ -540,6 +542,36 @@ class Backend(): """ raise NotImplementedError() + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): r""" Creates a sparse tensor in COOrdinate format. @@ -632,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + rng_ = np.random.RandomState() + def to_numpy(self, a): return a @@ -793,6 +827,16 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + return self.rng_.rand(*size) + + def randn(self, *size, type_as=None): + return self.rng_.randn(*size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return coo_matrix((data, (rows, cols)), shape=shape) @@ -845,6 +889,11 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + rng_ = None + + def __init__(self): + self.rng_ = jax.random.PRNGKey(42) + def to_numpy(self, a): return np.array(a) @@ -1010,6 +1059,24 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_ = jax.random.PRNGKey(seed) + + def rand(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.uniform(subkey, shape=size) + + def randn(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.normal(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.normal(subkey, shape=size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): # Currently, JAX does not support sparse matrices data = self.to_numpy(data) @@ -1064,8 +1131,13 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + rng_ = None + def __init__(self): + self.rng_ = torch.Generator() + self.rng_.seed() + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1102,12 +1174,16 @@ class TorchBackend(Backend): return res def zeros(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.zeros(shape) else: return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) def ones(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.ones(shape) else: @@ -1120,6 +1196,8 @@ class TorchBackend(Backend): return torch.arange(start, stop, step, device=type_as.device) def full(self, shape, fill_value, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.full(shape, fill_value) else: @@ -1293,6 +1371,26 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_.manual_seed(seed) + elif isinstance(seed, torch.Generator): + self.rng_ = seed + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is not None: + return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + else: + return torch.rand(size=size, generator=self.rng_) + + def randn(self, *size, type_as=None): + if type_as is not None: + return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + else: + return torch.randn(size=size, generator=self.rng_) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 4e95ccf..2c18a88 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -13,20 +13,23 @@ import multiprocessing import sys import numpy as np -from scipy.sparse import coo_matrix import warnings from . import cvx from .cvx import barycenter + # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d + from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend -__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d'] + def check_number_threads(numThreads): """Checks whether or not the requested number of threads has a valid value. @@ -115,10 +118,10 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): .. warning:: This function is necessary because the C++ solver in emd_c - discards all samples in the distributions with - zeros weights. This means that while the primal variable (transport + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport matrix) is exact, the solver only returns feasible dual potentials - on the samples with weights different from zero. + on the samples with weights different from zero. First we compute the constraints violations: @@ -215,26 +218,26 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): format .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. + from all compatible backends. Uses the algorithm proposed in [1]_ Parameters ---------- - a : (ns,) array-like, float + a : (ns,) array-like, float Source histogram (uniform weight if empty list) - b : (nt,) array-like, float - Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) - numItermax : int, optional (default=100000) + b : (nt,) array-like, float + Target histogram (uniform weight if empty list) + M : (ns,nt) array-like, float + Loss matrix (c-order array in numpy with type float64) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization - algorithm if it has not converged. - log: bool, optional (default=False) - If True, returns a dictionary containing the cost and dual variables. - Otherwise returns only the optimal transportation matrix. + algorithm if it has not converged. + log: bool, optional (default=False) + If True, returns a dictionary containing the cost and dual variables. + Otherwise returns only the optimal transportation matrix. center_dual: boolean, optional (default=True) - If True, centers the dual potential using function + If True, centers the dual potential using function :ref:`center_ot_dual`. numThreads: int or "max", optional (default=1, i.e. OpenMP is not used) If compiled with OpenMP, chooses the number of threads to parallelize. @@ -242,9 +245,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): Returns ------- - gamma: array-like, shape (ns, nt) + gamma: array-like, shape (ns, nt) Optimal transportation matrix for the given - parameters + parameters log: dict, optional If input log is true, a dictionary containing the cost and dual variables and exit status @@ -277,10 +280,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): regularized OT""" # convert to numpy if list - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -302,9 +305,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1): "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass - np.testing.assert_almost_equal(a.sum(0), - b.sum(0), err_msg='a and b vector must have the same sum') - b=b*a.sum()/b.sum() + np.testing.assert_almost_equal(a.sum(0), + b.sum(0), err_msg='a and b vector must have the same sum') + b = b * a.sum() / b.sum() asel = a != 0 bsel = b != 0 @@ -415,10 +418,10 @@ def emd2(a, b, M, processes=1, ot.bregman.sinkhorn : Entropic regularized OT ot.optim.cg : General regularized OT""" - a, b, M = list_to_array(a, b, M) + a, b, M = list_to_array(a, b, M) a0, b0, M0 = a, b, M - nx = get_backend(M0, a0, b0) + nx = get_backend(M0, a0, b0) # convert to numpy M = nx.to_numpy(M) @@ -427,7 +430,7 @@ def emd2(a, b, M, processes=1, a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order= 'C') + M = np.asarray(M, dtype=np.float64, order='C') # if empty array given then use uniform distributions if len(a) == 0: @@ -463,8 +466,8 @@ def emd2(a, b, M, processes=1, log['v'] = nx.from_numpy(v, type_as=b0) log['warning'] = result_code_string log['result_code'] = result_code - cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (log['u'], log['v'], G)) + cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), + (a0, b0, M0), (log['u'], log['v'], G)) return [cost, log] else: def f(b): @@ -479,8 +482,8 @@ def emd2(a, b, M, processes=1, G = nx.from_numpy(G, type_as=M0) cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0), - (a0,b0, M0), (nx.from_numpy(u, type_as=a0), - nx.from_numpy(v, type_as=b0),G)) + (a0, b0, M0), (nx.from_numpy(u, type_as=a0), + nx.from_numpy(v, type_as=b0), G)) check_result(result_code) return cost @@ -603,305 +606,3 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None return X, log_dict else: return X - - -def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the OT matrix - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - log: boolean, optional (default=False) - If True, returns a dictionary containing the cost. - Otherwise returns only the optimal transportation matrix. - - Returns - ------- - gamma: (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log: dict - If input log is True, a dictionary containing the cost - - - Examples - -------- - - Simple example with obvious solution. The function emd_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd_1d(x_a, x_b, a, b) - array([[0. , 0.5], - [0.5, 0. ]]) - >>> ot.emd_1d(x_a, x_b) - array([[0. , 0.5], - [0.5, 0. ]]) - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd : EMD for multidimensional distributions - ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the - transportation matrix) - """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) - nx = get_backend(x_a, x_b) - - assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ - "emd_1d should only be used with monodimensional data" - - # if empty array given then use uniform distributions - if a is None or a.ndim == 0 or len(a) == 0: - a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] - if b is None or b.ndim == 0 or len(b) == 0: - b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] - - # ensure that same mass - np.testing.assert_almost_equal( - nx.sum(a, axis=0), - nx.sum(b, axis=0), - err_msg='a and b vector must have the same sum' - ) - b = b * nx.sum(a) / nx.sum(b) - - x_a_1d = nx.reshape(x_a, (-1,)) - x_b_1d = nx.reshape(x_b, (-1,)) - perm_a = nx.argsort(x_a_1d) - perm_b = nx.argsort(x_b_1d) - - G_sorted, indices, cost = emd_1d_sorted( - nx.to_numpy(a[perm_a]), - nx.to_numpy(b[perm_b]), - nx.to_numpy(x_a_1d[perm_a]), - nx.to_numpy(x_b_1d[perm_b]), - metric=metric, p=p - ) - - G = nx.coo_matrix( - G_sorted, - perm_a[indices[:, 0]], - perm_b[indices[:, 1]], - shape=(a.shape[0], b.shape[0]), - type_as=x_a - ) - if dense: - G = nx.todense(G) - elif str(nx) == "jax": - warnings.warn("JAX does not support sparse matrices, converting to dense") - if log: - log = {'cost': cost} - return G, log - return G - - -def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - log=False): - r"""Solves the Earth Movers distance problem between 1d measures and returns - the loss - - - .. math:: - \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - where : - - - d is the metric - - x_a and x_b are the samples - - a and b are the sample weights - - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. - p: float, optional (default=1.0) - The p-norm to apply for if metric='minkowski' - dense: boolean, optional (default=True) - If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). - Otherwise returns a sparse representation using scipy's `coo_matrix` - format. Only used if log is set to True. Due to implementation details, - this function runs faster when dense is set to False. - log: boolean, optional (default=False) - If True, returns a dictionary containing the transportation matrix. - Otherwise returns only the loss. - - Returns - ------- - loss: float - Cost associated to the optimal transportation - log: dict - If input log is True, a dictionary containing the Optimal transportation - matrix for the given parameters - - - Examples - -------- - - Simple example with obvious solution. The function emd2_1d accepts lists and - performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.emd2_1d(x_a, x_b, a, b) - 0.5 - >>> ot.emd2_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd2 : EMD for multidimensional distributions - ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix - instead of the cost) - """ - # If we do not return G (log==False), then we should not to cast it to dense - # (useless overhead) - G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, - dense=dense and log, log=True) - cost = log_emd['cost'] - if log: - log_emd = {'G': G} - return cost, log_emd - return cost - - -def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.): - r"""Solves the p-Wasserstein distance problem between 1d measures and returns - the distance - - .. math:: - \min_\gamma \left( \sum_i \sum_j \gamma_{ij} \|x_a[i] - x_b[j]\|^p \right)^{1/p} - - s.t. \gamma 1 = a, - \gamma^T 1= b, - \gamma\geq 0 - - where : - - - x_a and x_b are the samples - - a and b are the sample weights - - Uses the algorithm detailed in [1]_ - - Parameters - ---------- - x_a : (ns,) or (ns, 1) ndarray, float64 - Source dirac locations (on the real line) - x_b : (nt,) or (ns, 1) ndarray, float64 - Target dirac locations (on the real line) - a : (ns,) ndarray, float64, optional - Source histogram (default is uniform weight) - b : (nt,) ndarray, float64, optional - Target histogram (default is uniform weight) - p: float, optional (default=1.0) - The order of the p-Wasserstein distance to be computed - - Returns - ------- - dist: float - p-Wasserstein distance - - - Examples - -------- - - Simple example with obvious solution. The function wasserstein_1d accepts - lists and performs automatic conversion to numpy arrays - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> x_a = [2., 0.] - >>> x_b = [0., 3.] - >>> ot.wasserstein_1d(x_a, x_b, a, b) - 0.5 - >>> ot.wasserstein_1d(x_a, x_b) - 0.5 - - References - ---------- - - .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal - Transport", 2018. - - See Also - -------- - ot.lp.emd_1d : EMD for 1d distributions - """ - cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p, - dense=False, log=False) - return np.power(cost_emd, 1. / p) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py new file mode 100644 index 0000000..42554aa --- /dev/null +++ b/ot/lp/solver_1d.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- +""" +Exact solvers for the 1D Wasserstein distance using cvxopt +""" + +# Author: Remi Flamary +# Author: Nicolas Courty +# +# License: MIT License + +import numpy as np +import warnings + +from .emd_wrap import emd_1d_sorted +from ..backend import get_backend +from ..utils import list_to_array + + +def quantile_function(qs, cws, xs): + r""" Computes the quantile function of an empirical distribution + + Parameters + ---------- + qs: array-like, shape (n,) + Quantiles at which the quantile function is evaluated + cws: array-like, shape (m, ...) + cumulative weights of the 1D empirical distribution, if batched, must be similar to xs + xs: array-like, shape (n, ...) + locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions + + Returns + ------- + q: array-like, shape (..., n) + The quantiles of the distribution + """ + nx = get_backend(qs, cws) + n = xs.shape[0] + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + cws = cws.T.contiguous() + qs = qs.T.contiguous() + else: + cws = cws.T + qs = qs.T + idx = nx.searchsorted(cws, qs).T + return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0) + + +def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True): + r""" + Computes the 1 dimensional OT loss [15] between two (batched) empirical + distributions + + .. math: + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + + It is formally the p-Wasserstein distance raised to the power p. + We do so in a vectorized way by first building the individual quantile functions then integrating them. + + This function should be preferred to `emd_1d` whenever the backend is + different to numpy, and when gradients over + either sample positions or weights are required. + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1 (see [2, Chap. 2], default is 1 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + + Returns + ------- + cost: float/array-like, shape (...) + the batched EMD + + References + ---------- + .. [15] Peyré, G., & Cuturi, M. (2018). Computational Optimal Transport. + + """ + + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cumweights = nx.cumsum(u_weights, 0) + v_cumweights = nx.cumsum(v_weights, 0) + + qs = nx.sort(nx.concatenate((u_cumweights, v_cumweights), 0), 0) + u_quantiles = quantile_function(qs, u_cumweights, u_values) + v_quantiles = quantile_function(qs, v_cumweights, v_values) + qs = nx.zero_pad(qs, pad_width=[(1, 0)] + (qs.ndim - 1) * [(0, 0)]) + delta = qs[1:, ...] - qs[:-1, ...] + diff_quantiles = nx.abs(u_quantiles - v_quantiles) + + if p == 1: + return nx.sum(delta * nx.abs(diff_quantiles), axis=0) + return nx.sum(delta * nx.power(diff_quantiles, p), axis=0) + + +def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the OT matrix + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + log: boolean, optional (default=False) + If True, returns a dictionary containing the cost. + Otherwise returns only the optimal transportation matrix. + + Returns + ------- + gamma: (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log: dict + If input log is True, a dictionary containing the cost + + + Examples + -------- + + Simple example with obvious solution. The function emd_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd_1d(x_a, x_b, a, b) + array([[0. , 0.5], + [0.5, 0. ]]) + >>> ot.emd_1d(x_a, x_b) + array([[0. , 0.5], + [0.5, 0. ]]) + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd : EMD for multidimensional distributions + ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the + transportation matrix) + """ + a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + nx = get_backend(x_a, x_b) + + assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ + "emd_1d should only be used with monodimensional data" + + # if empty array given then use uniform distributions + if a is None or a.ndim == 0 or len(a) == 0: + a = nx.ones((x_a.shape[0],), type_as=x_a) / x_a.shape[0] + if b is None or b.ndim == 0 or len(b) == 0: + b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0] + + # ensure that same mass + np.testing.assert_almost_equal( + nx.sum(a, axis=0), + nx.sum(b, axis=0), + err_msg='a and b vector must have the same sum' + ) + b = b * nx.sum(a) / nx.sum(b) + + x_a_1d = nx.reshape(x_a, (-1,)) + x_b_1d = nx.reshape(x_b, (-1,)) + perm_a = nx.argsort(x_a_1d) + perm_b = nx.argsort(x_b_1d) + + G_sorted, indices, cost = emd_1d_sorted( + nx.to_numpy(a[perm_a]), + nx.to_numpy(b[perm_b]), + nx.to_numpy(x_a_1d[perm_a]), + nx.to_numpy(x_b_1d[perm_b]), + metric=metric, p=p + ) + + G = nx.coo_matrix( + G_sorted, + perm_a[indices[:, 0]], + perm_b[indices[:, 1]], + shape=(a.shape[0], b.shape[0]), + type_as=x_a + ) + if dense: + G = nx.todense(G) + elif str(nx) == "jax": + warnings.warn("JAX does not support sparse matrices, converting to dense") + if log: + log = {'cost': cost} + return G, log + return G + + +def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, + log=False): + r"""Solves the Earth Movers distance problem between 1d measures and returns + the loss + + + .. math:: + \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j]) + + s.t. \gamma 1 = a, + \gamma^T 1= b, + \gamma\geq 0 + where : + + - d is the metric + - x_a and x_b are the samples + - a and b are the sample weights + + When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + + Uses the algorithm detailed in [1]_ + + Parameters + ---------- + x_a : (ns,) or (ns, 1) ndarray, float64 + Source dirac locations (on the real line) + x_b : (nt,) or (ns, 1) ndarray, float64 + Target dirac locations (on the real line) + a : (ns,) ndarray, float64, optional + Source histogram (default is uniform weight) + b : (nt,) ndarray, float64, optional + Target histogram (default is uniform weight) + metric: str, optional (default='sqeuclidean') + Metric to be used. Only strings listed in :func:`ot.dist` are accepted. + Due to implementation details, this function runs faster when + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics + are used. + p: float, optional (default=1.0) + The p-norm to apply for if metric='minkowski' + dense: boolean, optional (default=True) + If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). + Otherwise returns a sparse representation using scipy's `coo_matrix` + format. Only used if log is set to True. Due to implementation details, + this function runs faster when dense is set to False. + log: boolean, optional (default=False) + If True, returns a dictionary containing the transportation matrix. + Otherwise returns only the loss. + + Returns + ------- + loss: float + Cost associated to the optimal transportation + log: dict + If input log is True, a dictionary containing the Optimal transportation + matrix for the given parameters + + + Examples + -------- + + Simple example with obvious solution. The function emd2_1d accepts lists and + performs automatic conversion to numpy arrays + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> x_a = [2., 0.] + >>> x_b = [0., 3.] + >>> ot.emd2_1d(x_a, x_b, a, b) + 0.5 + >>> ot.emd2_1d(x_a, x_b) + 0.5 + + References + ---------- + + .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. + + See Also + -------- + ot.lp.emd2 : EMD for multidimensional distributions + ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix + instead of the cost) + """ + # If we do not return G (log==False), then we should not to cast it to dense + # (useless overhead) + G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p, + dense=dense and log, log=True) + cost = log_emd['cost'] + if log: + log_emd = {'G': G} + return cost, log_emd + return cost diff --git a/ot/sliced.py b/ot/sliced.py index 4792576..d3dc3f2 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,61 +1,73 @@ """ -Sliced Wasserstein Distance. +Sliced OT Distances """ # Author: Adrien Corenflos +# Nicolas Courty +# Rémi Flamary # # License: MIT License import numpy as np +from .backend import get_backend, NumpyBackend +from .utils import list_to_array -def get_random_projections(n_projections, d, seed=None): +def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): r""" Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` Parameters ---------- - n_projections : int - number of samples requested d : int dimension of the space + n_projections : int + number of samples requested seed: int or RandomState, optional Seed used for numpy random number generator + backend: + Backend to ue for random generation Returns ------- - out: ndarray, shape (n_projections, d) + out: ndarray, shape (d, n_projections) The uniform unit vectors on the sphere Examples -------- >>> n_projections = 100 >>> d = 5 - >>> projs = get_random_projections(n_projections, d) - >>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE + >>> projs = get_random_projections(d, n_projections) + >>> np.allclose(np.sum(np.square(projs), 0), 1.) # doctest: +NORMALIZE_WHITESPACE True """ - if not isinstance(seed, np.random.RandomState): - random_state = np.random.RandomState(seed) + if backend is None: + nx = NumpyBackend() + else: + nx = backend + + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + projections = seed.randn(d, n_projections) else: - random_state = seed + if seed is not None: + nx.seed(seed) + projections = nx.randn(d, n_projections, type_as=type_as) - projections = random_state.normal(0., 1., [n_projections, d]) - norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) - projections = projections / norm + projections = projections / nx.sqrt(nx.sum(projections**2, 0, keepdims=True)) return projections -def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): r""" - Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance .. math:: - \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + \mathcal{SWD}_p(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_p^p(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{p}} where : @@ -74,8 +86,12 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed samples weights in the target domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) seed: int or RandomState or None, optional - Seed used for numpy random number generator + Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. @@ -100,10 +116,18 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 """ - from .lp import emd2_1d + from .lp import wasserstein_1d - X_s = np.asanyarray(X_s) - X_t = np.asanyarray(X_t) + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) n = X_s.shape[0] m = X_t.shape[0] @@ -114,31 +138,120 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed X_t.shape[1])) if a is None: - a = np.full(n, 1 / n) + a = nx.full(n, 1 / n) if b is None: - b = np.full(m, 1 / m) + b = nx.full(m, 1 / m) d = X_s.shape[1] - projections = get_random_projections(n_projections, d, seed) + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) + + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - X_s_projections = np.dot(projections, X_s.T) - X_t_projections = np.dot(projections, X_t.T) + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) + res = (nx.sum(projected_emd) / n_projections) ** (1.0 / p) if log: - projected_emd = np.empty(n_projections) + return res, {"projections": projections, "projected_emds": projected_emd} + return res + + +def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, + projections=None, seed=None, log=False): + r""" + Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance + + .. math:: + \mathcal{Max-SWD}_p(\mu, \nu) = \underset{\theta _in + \mathcal{U}(\mathbb{S}^{d-1})}{\max} [\mathcal{W}_p^p(\theta_\# + \mu, \theta_\# \nu)]^{\frac{1}{p}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + + .. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). Max-sliced wasserstein distance and its use for gans. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). + """ + from .lp import wasserstein_1d + + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) else: - projected_emd = None + nx = get_backend(X_s, X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + + if a is None: + a = nx.full(n, 1 / n) + if b is None: + b = nx.full(m, 1 / m) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s) - res = 0. + X_s_projections = nx.dot(X_s, projections) + X_t_projections = nx.dot(X_t, projections) - for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): - emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) - if projected_emd is not None: - projected_emd[i] = emd - res += emd + projected_emd = wasserstein_1d(X_s_projections, X_t_projections, a, b, p=p) - res = (res / n_projections) ** 0.5 + res = nx.max(projected_emd) ** (1.0 / p) if log: return res, {"projections": projections, "projected_emds": projected_emd} return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py new file mode 100644 index 0000000..2c470c2 --- /dev/null +++ b/test/test_1d_solver.py @@ -0,0 +1,85 @@ +"""Tests for module 1d Wasserstein solver""" + +# Author: Adrien Corenflos +# Nicolas Courty +# +# License: MIT License + +import numpy as np +import pytest + +import ot +from ot.lp import wasserstein_1d + +from ot.backend import get_backend_list +from scipy.stats import wasserstein_distance + +backend_list = get_backend_list() + + +def test_emd_1d_emd2_1d_with_weights(): + # test emd1d gives similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.randn(n, 1) + v = rng.randn(m, 1) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M = ot.dist(u, v, metric='sqeuclidean') + + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] + G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) + wass1d = log["cost"] + wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) + wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) + + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, wass1d_emd2) + + # check loss is similar to scipy's implementation for Euclidean metric + wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) + np.testing.assert_allclose(wass_sp, wass1d_euc) + + # check constraints + np.testing.assert_allclose(w_u, G.sum(1)) + np.testing.assert_allclose(w_v, G.sum(0)) + + +@pytest.mark.parametrize('nx', backend_list) +def test_wasserstein_1d(nx): + from scipy.stats import wasserstein_distance + + rng = np.random.RandomState(0) + + n = 100 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + xb = nx.from_numpy(x) + rho_ub = nx.from_numpy(rho_u) + rho_vb = nx.from_numpy(rho_v) + + # test 1 : wasserstein_1d should be close to scipy W_1 implementation + np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1), + wasserstein_distance(x, x, rho_u, rho_v)) + + # test 2 : wasserstein_1d should be close to one when only translating the support + np.testing.assert_almost_equal(wasserstein_1d(xb, xb + 1, p=2), + 1.) + + # test 3 : arrays test + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) + res = wasserstein_1d(Xb, Xb, rho_ub, rho_vb, p=2) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) diff --git a/test/test_backend.py b/test/test_backend.py index 0f11ace..1832b91 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -208,6 +208,11 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.reshape(M, (5, 3, 2)) with pytest.raises(NotImplementedError): + nx.seed(42) + with pytest.raises(NotImplementedError): + nx.rand() + with pytest.raises(NotImplementedError): + nx.randn() nx.coo_matrix(M, M, M) with pytest.raises(NotImplementedError): nx.issparse(M) @@ -248,6 +253,7 @@ def test_func_backends(nx): Mb = nx.from_numpy(M) vb = nx.from_numpy(v) + val = nx.from_numpy(val) sp_rowb = nx.from_numpy(sp_row) @@ -255,6 +261,7 @@ def test_func_backends(nx): sp_datab = nx.from_numpy(sp_data) A = nx.set_gradients(val, v, v) + lst_b.append(nx.to_numpy(A)) lst_name.append('set_gradients') @@ -505,6 +512,35 @@ def test_func_backends(nx): assert np.allclose(a1, a2, atol=1e-7) +def test_random_backends(nx): + + tmp_u = nx.rand() + + assert tmp_u < 1 + + tmp_n = nx.randn() + + nx.seed(0) + M1 = nx.to_numpy(nx.rand(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.rand(5, 2, type_as=tmp_n)) + + assert np.all(M1 >= 0) + assert np.all(M1 < 1) + assert M1.shape == (5, 2) + assert np.allclose(M1, M2) + + nx.seed(0) + M1 = nx.to_numpy(nx.randn(5, 2)) + nx.seed(0) + M2 = nx.to_numpy(nx.randn(5, 2, type_as=tmp_u)) + + nx.seed(42) + v1 = nx.randn() + v2 = nx.randn() + assert v1 != v2 + + def test_gradients_backends(): rnd = np.random.RandomState(0) diff --git a/test/test_ot.py b/test/test_ot.py index 4dfc510..5bfde1d 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -8,11 +8,11 @@ import warnings import numpy as np import pytest -from scipy.stats import wasserstein_distance import ot from ot.datasets import make_1D_gauss as gauss from ot.backend import torch +from scipy.stats import wasserstein_distance def test_emd_dimension_and_mass_mismatch(): @@ -165,61 +165,6 @@ def test_emd_1d_emd2_1d(): ot.emd_1d(u, v, [], []) -def test_emd_1d_emd2_1d_with_weights(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - w_u = rng.uniform(0., 1., n) - w_u = w_u / w_u.sum() - - w_v = rng.uniform(0., 1., m) - w_v = w_v / w_v.sum() - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd(w_u, w_v, M, log=True) - wass = log["cost"] - G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True) - wass1d = log["cost"] - wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False) - wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False) - - # check loss is similar - np.testing.assert_allclose(wass, wass1d) - np.testing.assert_allclose(wass, wass1d_emd2) - - # check loss is similar to scipy's implementation for Euclidean metric - wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)), w_u, w_v) - np.testing.assert_allclose(wass_sp, wass1d_euc) - - # check constraints - np.testing.assert_allclose(w_u, G.sum(1)) - np.testing.assert_allclose(w_v, G.sum(0)) - - -def test_wass_1d(): - # test emd1d gives similar results as emd - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.randn(n, 1) - v = rng.randn(m, 1) - - M = ot.dist(u, v, metric='sqeuclidean') - - G, log = ot.emd([], [], M, log=True) - wass = log["cost"] - - wass1d = ot.wasserstein_1d(u, v, [], [], p=2.) - - # check loss is similar - np.testing.assert_allclose(np.sqrt(wass), wass1d) - - def test_emd_empty(): # test emd and emd2 for simple identity n = 100 diff --git a/test/test_sliced.py b/test/test_sliced.py index a07d975..0bd74ec 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -1,6 +1,7 @@ """Tests for module sliced""" # Author: Adrien Corenflos +# Nicolas Courty # # License: MIT License @@ -14,7 +15,7 @@ from ot.sliced import get_random_projections def test_get_random_projections(): rng = np.random.RandomState(0) projections = get_random_projections(1000, 50, rng) - np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) + np.testing.assert_almost_equal(np.sum(projections ** 2, 0), 1.) def test_sliced_same_dist(): @@ -48,12 +49,12 @@ def test_sliced_log(): y = rng.randn(n, 4) u = ot.utils.unif(n) - res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, p=1, seed=rng, log=True) assert len(log) == 2 projections = log["projections"] projected_emds = log["projected_emds"] - assert len(projections) == len(projected_emds) == 10 + assert projections.shape[1] == len(projected_emds) == 10 for emd in projected_emds: assert emd > 0 @@ -83,3 +84,86 @@ def test_1d_sliced_equals_emd(): res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) np.testing.assert_almost_equal(res ** 2, expected) + + +def test_max_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.max_sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_max_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res, log = ot.max_sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert res > 0. + + +def test_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.sliced_wasserstein_distance(x, y, projections=P) + + val = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) + + +def test_max_sliced_backend(nx): + + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(2 * n, 2) + + P = rng.randn(2, 20) + P = P / np.sqrt((P**2).sum(0, keepdims=True)) + + n_projections = 20 + + xb = nx.from_numpy(x) + yb = nx.from_numpy(y) + Pb = nx.from_numpy(P) + + val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P) + + val = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + val2 = ot.max_sliced_wasserstein_distance(xb, yb, n_projections=n_projections, seed=0) + + assert val > 0 + assert val == val2 + + valb = nx.to_numpy(ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)) + + assert np.allclose(val0, valb) diff --git a/test/test_utils.py b/test/test_utils.py index 0650ce2..40f4e49 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -109,7 +109,7 @@ def test_dist(): D2 = ot.dist(x, x) D3 = ot.dist(x) - D4 = ot.dist(x, x, metric='minkowski', p=0.5) + D4 = ot.dist(x, x, metric='minkowski', p=2) assert D4[0, 1] == D4[1, 0] -- cgit v1.2.3 From e1b67c641da3b3e497db6811af2c200022b10302 Mon Sep 17 00:00:00 2001 From: Hicham Janati Date: Wed, 3 Nov 2021 08:41:35 +0100 Subject: [WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add debiased sinkhorn barycenter + make loops pythonic * add debiased arg in tests * add 1d and 2d examples of debiased barycenters * fix doctest * fix flake8 * pep8 + make func private + add convergence warnings * remove rel paths + add rng + pylab to pyplot * fix stopping criterion debiased * pass alex * change params with new API * add logdomain barycenters + separate debiased API * test new API * fix jax read-only ? * raise error for jax * test catch jax error * fix pytest catch error * fix relative path * fix flake8 * add warn arg everywhere * fix ref number * catch warnings in tests * add contrib to readme + change ref number * fix convolution example + gallery thumbnails * increase coverage * fix flake Co-authored-by: Hicham Janati Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- README.md | 8 +- examples/barycenters/plot_barycenter_1D.py | 63 +- .../barycenters/plot_barycenter_lp_vs_entropic.py | 2 +- .../barycenters/plot_convolutional_barycenter.py | 53 +- examples/barycenters/plot_debiased_barycenter.py | 131 ++ .../domain-adaptation/plot_otda_color_images.py | 118 +- .../domain-adaptation/plot_otda_linear_mapping.py | 73 +- .../plot_otda_mapping_colors_images.py | 118 +- examples/gromov/plot_gromov_barycenter.py | 90 +- ot/bregman.py | 1491 +++++++++++++++----- test/test_bregman.py | 365 ++++- 11 files changed, 1837 insertions(+), 675 deletions(-) create mode 100644 examples/barycenters/plot_debiased_barycenter.py (limited to 'examples') diff --git a/README.md b/README.md index cfb9744..ff32c53 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,8 @@ POT provides the following generic OT solvers (links to examples): * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. * Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. -* Sinkhorn divergence [23] and entropic regularization OT from empirical data. +* Sinkhorn divergence [23] and entropic regularization OT from empirical data. +* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) @@ -188,7 +189,7 @@ The contributors to this library are * [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers) * [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home) * [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein) -* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT) +* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters) * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) @@ -293,3 +294,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t (2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. + +[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International +Conference on Machine Learning, PMLR 119:4692-4701, 2020 \ No newline at end of file diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 63dc460..2373e99 100644 --- a/examples/barycenters/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138. # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 1 import numpy as np -import matplotlib.pylab as pl +import matplotlib.pyplot as plt import ot # necessary for 3d plot even if not used from mpl_toolkits.mplot3d import Axes3D # noqa @@ -50,18 +50,6 @@ n_distributions = A.shape[1] M = ot.utils.dist0(n) M /= M.max() -############################################################################## -# Plot data -# --------- - -#%% plot the distributions - -pl.figure(1, figsize=(6.4, 3)) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') -pl.tight_layout() - ############################################################################## # Barycenter computation # ---------------------- @@ -78,24 +66,20 @@ bary_l2 = A.dot(weights) reg = 1e-3 bary_wass = ot.bregman.barycenter(A, M, reg, weights) -pl.figure(2) -pl.clf() -pl.subplot(2, 1, 1) -for i in range(n_distributions): - pl.plot(x, A[:, i]) -pl.title('Distributions') +f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1) +ax1.plot(x, A, color="black") +ax1.set_title('Distributions') -pl.subplot(2, 1, 2) -pl.plot(x, bary_l2, 'r', label='l2') -pl.plot(x, bary_wass, 'g', label='Wasserstein') -pl.legend() -pl.title('Barycenters') -pl.tight_layout() +ax2.plot(x, bary_l2, 'r', label='l2') +ax2.plot(x, bary_wass, 'g', label='Wasserstein') +ax2.set_title('Barycenters') + +plt.legend() +plt.show() ############################################################################## # Barycentric interpolation # ------------------------- - #%% barycenter interpolation n_alpha = 11 @@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha)) B_wass = np.copy(B_l2) -for i in range(0, n_alpha): +for i in range(n_alpha): alpha = alpha_list[i] weights = np.array([1 - alpha, alpha]) B_l2[:, i] = A.dot(weights) B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights) #%% plot interpolation +plt.figure(2) -pl.figure(3) - -cmap = pl.cm.get_cmap('viridis') +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_l2[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$') ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with l2') -pl.tight_layout() +plt.title('Barycenter interpolation with l2') +plt.tight_layout() -pl.figure(4) -cmap = pl.cm.get_cmap('viridis') +plt.figure(3) +cmap = plt.cm.get_cmap('viridis') verts = [] zs = alpha_list for i, z in enumerate(zs): ys = B_wass[:, i] verts.append(list(zip(x, ys))) -ax = pl.gcf().gca(projection='3d') +ax = plt.gcf().gca(projection='3d') poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list]) poly.set_alpha(0.7) @@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$') ax.set_ylim3d(0, 1) ax.set_zlabel('') ax.set_zlim3d(0, B_l2.max() * 1.01) -pl.title('Barycenter interpolation with Wasserstein') -pl.tight_layout() +plt.title('Barycenter interpolation with Wasserstein') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py index 57a6bac..6502f16 100644 --- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ ================================================================================= -1D Wasserstein barycenter comparison between exact LP and entropic regularization +1D Wasserstein barycenter: exact LP vs entropic regularization ================================================================================= This example illustrates the computation of regularized Wasserstein Barycenter diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index cbcd4a1..3721f31 100644 --- a/examples/barycenters/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -6,17 +6,18 @@ Convolutional Wasserstein Barycenter example ============================================ -This example is designed to illustrate how the Convolutional Wasserstein Barycenter -function of POT works. +This example is designed to illustrate how the Convolutional Wasserstein +Barycenter function of POT works. """ # Author: Nicolas Courty # # License: MIT License - +import os +from pathlib import Path import numpy as np -import pylab as pl +import matplotlib.pyplot as plt import ot ############################################################################## @@ -25,22 +26,19 @@ import ot # # The four distributions are constructed from 4 simple images +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') -f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2] -f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2] -f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2] -f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2] +f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2] +f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] -A = [] f1 = f1 / np.sum(f1) f2 = f2 / np.sum(f2) f3 = f3 / np.sum(f3) f4 = f4 / np.sum(f4) -A.append(f1) -A.append(f2) -A.append(f3) -A.append(f4) -A = np.array(A) +A = np.array([f1, f2, f3, f4]) nb_images = 5 @@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1)) # ---------------------------------------- # -pl.figure(figsize=(10, 10)) -pl.title('Convolutional Wasserstein Barycenters in POT') +fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7)) +plt.suptitle('Convolutional Wasserstein Barycenters in POT') cm = 'Blues' # regularization parameter reg = 0.004 for i in range(nb_images): for j in range(nb_images): - pl.subplot(nb_images, nb_images, i * nb_images + j + 1) tx = float(i) / (nb_images - 1) ty = float(j) / (nb_images - 1) @@ -74,19 +71,19 @@ for i in range(nb_images): weights = (1 - ty) * tmp1 + ty * tmp2 if i == 0 and j == 0: - pl.imshow(f1, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f1, cmap=cm) elif i == 0 and j == (nb_images - 1): - pl.imshow(f3, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f3, cmap=cm) elif i == (nb_images - 1) and j == 0: - pl.imshow(f2, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f2, cmap=cm) elif i == (nb_images - 1) and j == (nb_images - 1): - pl.imshow(f4, cmap=cm) - pl.axis('off') + axes[i, j].imshow(f4, cmap=cm) else: # call to barycenter computation - pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm) - pl.axis('off') -pl.show() + axes[i, j].imshow( + ot.bregman.convolutional_barycenter2d(A, reg, weights), + cmap=cm + ) + axes[i, j].axis('off') +plt.tight_layout() +plt.show() diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py new file mode 100644 index 0000000..2a603dd --- /dev/null +++ b/examples/barycenters/plot_debiased_barycenter.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +""" +================================= +Debiased Sinkhorn barycenter demo +================================= + +This example illustrates the computation of the debiased Sinkhorn barycenter +as proposed in [37]_. + + +.. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 +""" + +# Author: Hicham Janati +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 3 + +import os +from pathlib import Path + +import numpy as np +import matplotlib.pyplot as plt + +import ot +from ot.bregman import (barycenter, barycenter_debiased, + convolutional_barycenter2d, + convolutional_barycenter2d_debiased) + +############################################################################## +# Debiased barycenter of 1D Gaussians +# ------------------------------------ + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) + +# creating matrix A containing all distributions +A = np.vstack((a1, a2)).T +n_distributions = A.shape[1] + +# loss matrix + normalization +M = ot.utils.dist0(n) +M /= M.max() + +#%% barycenter computation + +alpha = 0.2 # 0<=alpha<=1 +weights = np.array([1 - alpha, alpha]) + +epsilons = [5e-3, 1e-2, 5e-2] + + +bars = [barycenter(A, M, reg, weights) for reg in epsilons] +bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons] +labels = ["Sinkhorn barycenter", "Debiased barycenter"] +colors = ["indianred", "gold"] + +f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True, + figsize=(12, 4), num=1) +for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased): + ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3) + ax.plot(A[:, 1], color="k", ls="--", alpha=0.3) + for data, label, color in zip([bar, bar_debiased], labels, colors): + ax.plot(data, color=color, label=label, lw=2) + ax.set_title(r"$\varepsilon = %.3f$" % eps) +plt.legend() +plt.show() + + +############################################################################## +# Debiased barycenter of 2D images +# --------------------------------- +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2] +f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2] + +A = np.asarray([f1, f2]) + 1e-2 +A /= A.sum(axis=(1, 2))[:, None, None] + +############################################################################## +# Display the input images + +fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2) +for ax, img in zip(axes, A): + ax.imshow(img, cmap="Greys") + ax.axis("off") +fig.tight_layout() +plt.show() + + +############################################################################## +# Barycenter computation and visualization +# ---------------------------------------- +# + +bars_sinkhorn, bars_debiased = [], [] +epsilons = [5e-3, 7e-3, 1e-2] +for eps in epsilons: + bar = convolutional_barycenter2d(A, eps) + bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True) + bars_sinkhorn.append(bar) + bars_debiased.append(bar_debiased) + +titles = ["Sinkhorn", "Debiased"] +all_bars = [bars_sinkhorn, bars_debiased] +fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3) +for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)): + for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)): + ax.imshow(img, cmap="Greys") + if jj == 0: + ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13) + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['bottom'].set_visible(False) + ax.spines['left'].set_visible(False) + if ii == 0: + ax.set_ylabel(method, fontsize=15) +fig.tight_layout() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 6218b13..06dc8ab 100644 --- a/examples/domain-adaptation/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -19,12 +19,15 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882. # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path + import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -46,16 +49,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -65,39 +71,39 @@ Xt = X2[idx2, :] # Plot original image # ------------------- -pl.figure(1, figsize=(6.4, 3)) +plt.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') ############################################################################## # Scatter plot of colors # ---------------------- -pl.figure(2, figsize=(6.4, 3)) +plt.figure(2, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## @@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape)) # Plot new images # --------------- -pl.figure(3, figsize=(8, 4)) +plt.figure(3, figsize=(8, 4)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(2, 3, 2) -pl.imshow(I1t) -pl.axis('off') -pl.title('Image 1 Adapt') +plt.subplot(2, 3, 2) +plt.imshow(I1t) +plt.axis('off') +plt.title('Image 1 Adapt') -pl.subplot(2, 3, 3) -pl.imshow(I1te) -pl.axis('off') -pl.title('Image 1 Adapt (reg)') +plt.subplot(2, 3, 3) +plt.imshow(I1te) +plt.axis('off') +plt.title('Image 1 Adapt (reg)') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') -pl.subplot(2, 3, 5) -pl.imshow(I2t) -pl.axis('off') -pl.title('Image 2 Adapt') +plt.subplot(2, 3, 5) +plt.imshow(I2t) +plt.axis('off') +plt.title('Image 2 Adapt') -pl.subplot(2, 3, 6) -pl.imshow(I2te) -pl.axis('off') -pl.title('Image 2 Adapt (reg)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(I2te) +plt.axis('off') +plt.title('Image 2 Adapt (reg)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index be47510..a44096a 100644 --- a/examples/domain-adaptation/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -13,9 +13,11 @@ Linear OT mapping estimation # License: MIT License # sphinx_gallery_thumbnail_number = 2 +import os +from pathlib import Path import numpy as np -import pylab as pl +from matplotlib import pyplot as plt import ot ############################################################################## @@ -26,17 +28,19 @@ n = 1000 d = 2 sigma = .1 +rng = np.random.RandomState(42) + # source samples -angles = np.random.rand(n, 1) * 2 * np.pi +angles = rng.rand(n, 1) * 2 * np.pi xs = np.concatenate((np.sin(angles), np.cos(angles)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xs[:n // 2, 1] += 2 # target samples -anglet = np.random.rand(n, 1) * 2 * np.pi +anglet = rng.rand(n, 1) * 2 * np.pi xt = np.concatenate((np.sin(anglet), np.cos(anglet)), - axis=1) + sigma * np.random.randn(n, 2) + axis=1) + sigma * rng.randn(n, 2) xt[:n // 2, 1] += 2 @@ -48,9 +52,9 @@ xt = xt.dot(A) + b # Plot data # --------- -pl.figure(1, (5, 5)) -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') +plt.figure(1, (5, 5)) +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') ############################################################################## @@ -66,13 +70,13 @@ xst = xs.dot(Ae) + be # Plot transported samples # ------------------------ -pl.figure(1, (5, 5)) -pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') -pl.plot(xst[:, 0], xst[:, 1], '+') +plt.figure(1, (5, 5)) +plt.clf() +plt.plot(xs[:, 0], xs[:, 1], '+') +plt.plot(xt[:, 0], xt[:, 1], 'o') +plt.plot(xst[:, 0], xst[:, 1], '+') -pl.show() +plt.show() ############################################################################## # Load image data @@ -94,8 +98,11 @@ def minmax(img): # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) @@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape)) # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 7)) +plt.figure(2, figsize=(10, 7)) -pl.subplot(2, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 2, 3) -pl.imshow(I1t) -pl.axis('off') -pl.title('Mapping Im. 1') +plt.subplot(2, 2, 3) +plt.imshow(I1t) +plt.axis('off') +plt.title('Mapping Im. 1') -pl.subplot(2, 2, 4) -pl.imshow(I2t) -pl.axis('off') -pl.title('Inverse mapping Im. 2') +plt.subplot(2, 2, 4) +plt.imshow(I2t) +plt.axis('off') +plt.title('Inverse mapping Im. 2') diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index 72010a6..dbece70 100644 --- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. # License: MIT License # sphinx_gallery_thumbnail_number = 3 +import os +from pathlib import Path import numpy as np -import matplotlib.pylab as pl +from matplotlib import pyplot as plt import ot -r = np.random.RandomState(42) +rng = np.random.RandomState(42) def im2mat(img): @@ -48,17 +50,19 @@ def minmax(img): # ------------- # Loading images -I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') +I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256 +I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) # training samples nb = 500 -idx1 = r.randint(X1.shape[0], size=(nb,)) -idx2 = r.randint(X2.shape[0], size=(nb,)) +idx1 = rng.randint(X1.shape[0], size=(nb,)) +idx2 = rng.randint(X2.shape[0], size=(nb,)) Xs = X1[idx1, :] Xt = X2[idx2, :] @@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape)) # Plot original images # -------------------- -pl.figure(1, figsize=(6.4, 3)) -pl.subplot(1, 2, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Image 1') +plt.figure(1, figsize=(6.4, 3)) +plt.subplot(1, 2, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.imshow(I2) -pl.axis('off') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.imshow(I2) +plt.axis('off') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot pixel values distribution # ------------------------------ -pl.figure(2, figsize=(6.4, 5)) +plt.figure(2, figsize=(6.4, 5)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 1') +plt.subplot(1, 2, 1) +plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 1') -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt) -pl.axis([0, 1, 0, 1]) -pl.xlabel('Red') -pl.ylabel('Blue') -pl.title('Image 2') -pl.tight_layout() +plt.subplot(1, 2, 2) +plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt) +plt.axis([0, 1, 0, 1]) +plt.xlabel('Red') +plt.ylabel('Blue') +plt.title('Image 2') +plt.tight_layout() ############################################################################## # Plot transformed images # ----------------------- -pl.figure(2, figsize=(10, 5)) +plt.figure(2, figsize=(10, 5)) -pl.subplot(2, 3, 1) -pl.imshow(I1) -pl.axis('off') -pl.title('Im. 1') +plt.subplot(2, 3, 1) +plt.imshow(I1) +plt.axis('off') +plt.title('Im. 1') -pl.subplot(2, 3, 4) -pl.imshow(I2) -pl.axis('off') -pl.title('Im. 2') +plt.subplot(2, 3, 4) +plt.imshow(I2) +plt.axis('off') +plt.title('Im. 2') -pl.subplot(2, 3, 2) -pl.imshow(Image_emd) -pl.axis('off') -pl.title('EmdTransport') +plt.subplot(2, 3, 2) +plt.imshow(Image_emd) +plt.axis('off') +plt.title('EmdTransport') -pl.subplot(2, 3, 5) -pl.imshow(Image_sinkhorn) -pl.axis('off') -pl.title('SinkhornTransport') +plt.subplot(2, 3, 5) +plt.imshow(Image_sinkhorn) +plt.axis('off') +plt.title('SinkhornTransport') -pl.subplot(2, 3, 3) -pl.imshow(Image_mapping_linear) -pl.axis('off') -pl.title('MappingTransport (linear)') +plt.subplot(2, 3, 3) +plt.imshow(Image_mapping_linear) +plt.axis('off') +plt.title('MappingTransport (linear)') -pl.subplot(2, 3, 6) -pl.imshow(Image_mapping_gaussian) -pl.axis('off') -pl.title('MappingTransport (gaussian)') -pl.tight_layout() +plt.subplot(2, 3, 6) +plt.imshow(Image_mapping_gaussian) +plt.axis('off') +plt.title('MappingTransport (gaussian)') +plt.tight_layout() -pl.show() +plt.show() diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index e2d88ba..7fe081f 100755 --- a/examples/gromov/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -13,11 +13,13 @@ computation in POT. # # License: MIT License +import os +from pathlib import Path import numpy as np import scipy as sp -import matplotlib.pylab as pl +from matplotlib import pyplot as plt from sklearn import manifold from sklearn.decomposition import PCA @@ -89,17 +91,19 @@ def im2mat(img): return img.reshape((img.shape[0] * img.shape[1], img.shape[2])) -square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2] -cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2] -triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2] -star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2] +this_file = os.path.realpath('__file__') +data_path = os.path.join(Path(this_file).parent.parent.parent, 'data') + +square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2] +cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2] +triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2] +star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2] shapes = [square, cross, triangle, star] S = 4 xs = [[] for i in range(S)] - for nb in range(4): for i in range(8): for j in range(8): @@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)] npost23 = [clf.fit_transform(npost23[s]) for s in range(2)] -fig = pl.figure(figsize=(10, 10)) +fig = plt.figure(figsize=(10, 10)) -ax1 = pl.subplot2grid((4, 4), (0, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax1 = plt.subplot2grid((4, 4), (0, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r') -ax2 = pl.subplot2grid((4, 4), (0, 1)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax2 = plt.subplot2grid((4, 4), (0, 1)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b') -ax3 = pl.subplot2grid((4, 4), (0, 2)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax3 = plt.subplot2grid((4, 4), (0, 2)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b') -ax4 = pl.subplot2grid((4, 4), (0, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax4 = plt.subplot2grid((4, 4), (0, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r') -ax5 = pl.subplot2grid((4, 4), (1, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax5 = plt.subplot2grid((4, 4), (1, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b') -ax6 = pl.subplot2grid((4, 4), (1, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax6 = plt.subplot2grid((4, 4), (1, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b') -ax7 = pl.subplot2grid((4, 4), (2, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax7 = plt.subplot2grid((4, 4), (2, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b') -ax8 = pl.subplot2grid((4, 4), (2, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax8 = plt.subplot2grid((4, 4), (2, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b') -ax9 = pl.subplot2grid((4, 4), (3, 0)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax9 = plt.subplot2grid((4, 4), (3, 0)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r') -ax10 = pl.subplot2grid((4, 4), (3, 1)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax10 = plt.subplot2grid((4, 4), (3, 1)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b') -ax11 = pl.subplot2grid((4, 4), (3, 2)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax11 = plt.subplot2grid((4, 4), (3, 2)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b') -ax12 = pl.subplot2grid((4, 4), (3, 3)) -pl.xlim((-1, 1)) -pl.ylim((-1, 1)) +ax12 = plt.subplot2grid((4, 4), (3, 3)) +plt.xlim((-1, 1)) +plt.ylim((-1, 1)) ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r') diff --git a/ot/bregman.py b/ot/bregman.py index 0499b8e..786f151 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -7,7 +7,7 @@ Bregman projections solvers for entropic regularized OT # Nicolas Courty # Kilian Fatras # Titouan Vayer -# Hicham Janati +# Hicham Janati # Mokhtar Z. Alaya # Alexander Tong # Ievgen Redko @@ -25,7 +25,8 @@ from .backend import get_backend def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -43,8 +44,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. @@ -77,7 +80,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -94,6 +98,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -117,13 +123,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. See Also @@ -131,37 +145,44 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling + :ref:`[9] ` :ref:`[10] ` """ if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'greenkhorn': return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log) + stopThr=stopThr, verbose=verbose, log=log, + warn=warn) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) elif method.lower() == 'sinkhorn_epsilon_scaling': return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, + **kwargs) else: raise ValueError("Unknown method '%s'." % method) def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): + stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs): r""" Solve the entropic regularization optimal transport problem and return the loss @@ -179,13 +200,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) .. note:: This function is backend-compatible and will work on arrays from all compatible backends. - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` **Choosing a Sinkhorn solver** @@ -212,7 +236,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -228,6 +253,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -252,19 +279,27 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation - algorithms for optimal transport via Sinkhorn iteration, Advances in Neural - Information Processing Systems (NIPS) 31, 2017 - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. - + algorithms for optimal transport via Sinkhorn iteration, + Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). + Interpolating between optimal transport and MMD using Sinkhorn + divergences. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2681-2690). PMLR. See Also -------- @@ -272,7 +307,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, ot.optim.cg : General regularized OT ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] ` + :ref:`[10] ` """ @@ -317,8 +353,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, raise ValueError("Unknown method '%s'." % method) -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -335,10 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp + matrix scaling algorithm as proposed in :ref:`[2] ` Parameters @@ -347,7 +387,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -360,6 +401,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -384,7 +427,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 See Also @@ -427,9 +472,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, K = nx.exp(M / (-reg)) Kp = (1 / a).reshape(-1, 1) * K - cpt = 0 + err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): uprev = u vprev = v KtransposeU = nx.dot(K.T, u) @@ -441,11 +486,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Warning: numerical errors at iteration %d' % ii) u = uprev v = vprev break - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -457,13 +502,20 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['u'] = u log['v'] = v @@ -482,8 +534,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_log(a, b, M, reg, numItermax=1000, - stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem in log space and return the OT matrix @@ -528,6 +580,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -552,9 +606,15 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). Interpolating between + optimal transport and MMD using Sinkhorn divergences. In The + 22nd International Conference on Artificial Intelligence and + Statistics (pp. 2681-2690). PMLR. See Also @@ -613,7 +673,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, if log: log = {'err': []} - Mr = M / (-reg) + Mr = - M / reg # we assume that no distances are null except those of the diagonal of # distances @@ -630,14 +690,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, loga = nx.log(a) logb = nx.log(b) - cpt = 0 err = 1 - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): v = logb - nx.logsumexp(Mr + u[:, None], 0) u = loga - nx.logsumexp(Mr + v[None, :], 1) - if cpt % 10 == 0: + if ii % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations @@ -648,13 +707,20 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - cpt = cpt + 1 + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log['niter'] = ii log['log_u'] = u log['log_v'] = v log['u'] = nx.exp(u) @@ -667,11 +733,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False): + log=False, warn=True): r""" Solve the entropic regularization optimal transport problem and return the OT matrix - The algorithm used is based on the paper :ref:`[22] ` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] ` + The algorithm used is based on the paper :ref:`[22] ` + which is a stochastic version of the Sinkhorn-Knopp + algorithm :ref:`[2] ` The function solves the following optimization problem: @@ -686,8 +754,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) Parameters @@ -696,7 +766,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, samples weights in the source domain b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) M : array-like, shape (dim_a, dim_b) loss matrix reg : float @@ -707,6 +778,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, Stop threshold on error (>0) log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -731,9 +804,14 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time + approximation algorithms for optimal transport via Sinkhorn + iteration, Advances in Neural Information Processing + Systems (NIPS) 31, 2017 See Also @@ -747,7 +825,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX") + raise TypeError("JAX arrays have been received. Greenkhorn is not " + "compatible with JAX") if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -771,7 +850,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log['u'] = u log['v'] = v - for i in range(numItermax): + for ii in range(numItermax): i_1 = nx.argmax(nx.abs(viol)) i_2 = nx.argmax(nx.abs(viol_2)) m_viol_1 = nx.abs(viol[i_1]) @@ -795,14 +874,17 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, viol += (-old_v + new_v) * K[:, i_2] * u viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] v[i_2] = new_v - # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: break else: - print('Warning: Algorithm did not converge') + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: + log["n_iter"] = ii log['u'] = u log['v'] = v @@ -814,7 +896,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization OT problem with log stabilization @@ -831,13 +913,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, where : - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` an defined in :ref:`[9] ` (Algo 3.1) . + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` an defined in + :ref:`[9] ` (Algo 3.1) . Parameters @@ -851,7 +937,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling warmstart : table of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -862,6 +949,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -886,11 +975,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. See Also @@ -920,7 +1015,6 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, dim_a = len(a) dim_b = len(b) - cpt = 0 if log: log = {'err': []} @@ -935,7 +1029,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, u = nx.ones((dim_a, n_hists), type_as=M) / dim_a v = nx.ones((dim_b, n_hists), type_as=M) / dim_b else: - u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b + u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) + u /= dim_a + v /= dim_b def get_K(alpha, beta): """log space computation""" @@ -947,21 +1043,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) - # print(np.min(K)) - K = get_K(alpha, beta) transp = K - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): uprev = u vprev = v # sinkhorn update - v = b / (nx.dot(K.T, u) + 1e-16) - u = a / (nx.dot(K, v) + 1e-16) + v = b / (nx.dot(K.T, u)) + u = a / (nx.dot(K, v)) # remove numerical problems and store them in K if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: @@ -977,7 +1069,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, v = nx.ones(dim_b, type_as=M) / dim_b K = get_K(alpha, beta) - if cpt % print_period == 0: + if ii % print_period == 0: # we can speed up the process by checking for the error only all # the 10th iterations if n_hists: @@ -993,33 +1085,33 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, log['err'].append(err) if verbose: - if cpt % (print_period * 20) == 0: + if ii % (print_period * 20) == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) if err <= stopThr: - loop = False - - if cpt >= numItermax: - loop = False + break if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): # we have reached the machine precision # come back to previous solution and quit loop - print('Warning: numerical errors at iteration', cpt) + warnings.warn('Numerical errors at iteration %d' % ii) u = uprev v = vprev break - - cpt = cpt + 1 - + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: if n_hists: alpha = alpha[:, None] beta = beta[:, None] logu = alpha / reg + nx.log(u) logv = beta / reg + nx.log(v) + log["n_iter"] = ii log['logu'] = logu log['logv'] = logv log['alpha'] = alpha + reg * nx.log(u) @@ -1048,13 +1140,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, - log=False, **kwargs): + log=False, warn=True, **kwargs): r""" Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. - The function solves the following optimization problem: - .. math:: \gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma) @@ -1064,16 +1154,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, \gamma &\geq 0 where : - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights + (histograms, both sum to 1) The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` but with the log stabilization - proposed in :ref:`[10] ` and the log scaling proposed in :ref:`[9] ` algorithm 3.2 - + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` and the log scaling + proposed in :ref:`[9] ` algorithm 3.2 Parameters ---------- @@ -1086,7 +1176,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, reg : float Regularization term >0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` + for log scaling warmstart : tuple of vectors if given then starting values for alpha and beta log scalings numItermax : int, optional @@ -1101,6 +1192,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1108,10 +1201,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters - Examples -------- - >>> import ot >>> a=[.5, .5] >>> b=[.5, .5] @@ -1123,19 +1214,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, .. _references-sinkhorn-epsilon-scaling: References ---------- + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. See Also -------- ot.lp.emd : Unregularized OT ot.optim.cg : General regularized OT - """ a, b, M = list_to_array(a, b, M) @@ -1155,7 +1246,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numItermin = 35 numItermax = max(numItermin, numItermax) # ensure that last velue is exact - cpt = 0 + ii = 0 if log: log = {'err': []} @@ -1170,12 +1261,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def get_reg(n): # exponential decreasing return (epsilon0 - reg) * np.exp(-n) + reg - loop = 1 - cpt = 0 err = 1 - while loop: + for ii in range(numItermax): - regi = get_reg(cpt) + regi = get_reg(ii) G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, @@ -1185,10 +1274,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, alpha = logi['alpha'] beta = logi['beta'] - if cpt >= numItermax: - loop = False - - if cpt % (print_period) == 0: # spsion nearly converged + if ii % (print_period) == 0: # spsion nearly converged # we can speed up the process by checking for the error only all # the 10th iterations transp = G @@ -1197,19 +1283,22 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, log['err'].append(err) if verbose: - if cpt % (print_period * 10) == 0: + if ii % (print_period * 10) == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - if err <= stopThr and cpt > numItermin: - loop = False + print('{:5d}|{:8e}|'.format(ii, err)) - cpt = cpt + 1 - # print('err=',err,' cpt=',cpt) + if err <= stopThr and ii > numItermin: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: log['alpha'] = alpha log['beta'] = beta log['warmstart'] = (log['alpha'], log['beta']) + log['niter'] = ii return G, log else: return G @@ -1245,7 +1334,7 @@ def projC(gamma, q): def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, **kwargs): + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1255,11 +1344,16 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1270,7 +1364,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, reg : float Regularization term > 0 method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' weights : array-like, shape (n_hists,) Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1281,6 +1375,8 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1295,7 +1391,9 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1303,18 +1401,24 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, return barycenter_sinkhorn(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return barycenter_stabilized(A, M, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, - log=log, **kwargs) + log=log, warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` The function solves the following optimization problem: @@ -1324,11 +1428,15 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3]`. Parameters ---------- @@ -1348,6 +1456,8 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1362,7 +1472,9 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1378,43 +1490,109 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if log: log = {'err': []} - # M = M/np.median(M) # suggested by G. Peyre K = nx.exp(-M / reg) - cpt = 0 err = 1 UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) u = (geometricMean(UKv) / UKv.T).T - while (err > stopThr and cpt < numItermax): - cpt = cpt + 1 + for ii in range(numItermax): + UKv = u * nx.dot(K, A / nx.dot(K, u)) u = (u.T * geometricBar(weights, UKv)).T / UKv - if cpt % 10 == 1: + if ii % 10 == 1: err = nx.sum(nx.std(UKv, axis=1)) # log and verbose print if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii return geometricBar(weights, UKv), log else: return geometricBar(weights, UKv) +def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic wasserstein barycenter in log-domain + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar = log_bar + weights[k] * log_KU[:, k] + + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False): + stopThr=1e-4, verbose=False, log=False, warn=True): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. The function solves the following optimization problem: @@ -1424,11 +1602,15 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` Parameters ---------- @@ -1439,7 +1621,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, reg : float Regularization term > 0 tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling weights : array-like, shape (n_hists,) Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) numItermax : int, optional @@ -1450,6 +1633,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -1464,7 +1649,9 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, References ---------- - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. """ @@ -1486,19 +1673,18 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, K = nx.exp(-M / reg) - cpt = 0 err = 1. alpha = nx.zeros((dim,), type_as=M) beta = nx.zeros((dim,), type_as=M) q = nx.ones((dim,), type_as=M) / dim - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): qprev = q Kv = nx.dot(K, v) - u = A / (Kv + 1e-16) + u = A / Kv Ktu = nx.dot(K.T, u) q = geometricBar(weights, Ktu) Q = q[:, None] - v = Q / (Ktu + 1e-16) + v = Q / Ktu absorbing = False if nx.any(u > tau) or nx.any(v > tau): absorbing = True @@ -1512,40 +1698,244 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % cpt) + warnings.warn('Numerical errors at iteration %s' % ii) q = qprev break - if (cpt % 10 == 0 and not absorbing) or cpt == 0: + if (ii % 10 == 0 and not absorbing) or ii == 0: # we can speed up the process by checking for the error only all # the 10th iterations err = nx.max(nx.abs(u * Kv - A)) if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 50 == 0: + if ii % 50 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) - cpt += 1 - if err > stopThr: - warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." + - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") + else: + if warn: + warnings.warn("Stabilized Sinkhorn did not converge." + + "Try a larger entropy `reg`" + + "Or a larger absorption threshold `tau`.") if log: - log['niter'] = cpt - log['logu'] = nx.log(u + 1e-16) - log['logv'] = nx.log(v + 1e-16) + log['niter'] = ii + log['logu'] = np.log(u + 1e-16) + log['logv'] = np.log(v + 1e-16) return q, log else: return q -def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` - where :math:`\mathbf{A}` is a collection of 2D images. +def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the debiased Sinkhorn barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence + (see :py:func:`ot.bregman.emirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the debiased Sinkhorn + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`a_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`a_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + .. _references-sinkhorn-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _barycenter_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_debiased_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the debiased sinkhorn barycenter of distributions A. + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + + for ii in range(numItermax): + bold = bar + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) + + # log and verbose print + if log: + log['err'].append(err) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return bar, log + else: + return bar + + +def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, + warn=True): + r"""Compute the debiased sinkhorn barycenter in log domain. + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + c = nx.zeros(dim, type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] + log_bar += c + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + for _ in range(10): + c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, + warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. The function solves the following optimization problem: @@ -1554,11 +1944,14 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, where : - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions + of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] ` + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm + as proposed in :ref:`[21] ` Parameters ---------- @@ -1568,6 +1961,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional @@ -1578,6 +1973,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1591,9 +1988,36 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, References ---------- - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, + A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: + Efficient optimal transportation on geometric domains. ACM Transactions + on Graphics (TOG), 34(4), 66 + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. """ A = list_to_array(A) @@ -1608,65 +2032,373 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if log: log = {'err': []} - b = nx.zeros(A.shape[1:], type_as=A) + bar = nx.ones(A.shape[1:], type_as=A) + bar /= bar.sum() U = nx.ones(A.shape, type_as=A) - KV = nx.ones(A.shape, type_as=A) - - cpt = 0 + V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator # this is equivalent to blurring on horizontal then vertical directions t = nx.linspace(0, 1, A.shape[1]) [Y, X] = nx.meshgrid(t, t) - xi1 = nx.exp(-(X - Y) ** 2 / reg) + K1 = nx.exp(-(X - Y) ** 2 / reg) t = nx.linspace(0, 1, A.shape[2]) [Y, X] = nx.meshgrid(t, t) - xi2 = nx.exp(-(X - Y) ** 2 / reg) - - def K(x): - return nx.dot(nx.dot(xi1, x), xi2) - - while (err > stopThr and cpt < numItermax): - - bold = b - cpt = cpt + 1 - - b = nx.zeros(A.shape[1:], type_as=A) - KV_cols = [] - for r in range(A.shape[0]): - KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :]))) - b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r)) - KV_cols.append(KV_col_r) - KV = nx.stack(KV_cols) - b = nx.exp(b) - - U = nx.stack([ - b / nx.maximum(stabThr, KV[r, :, :]) - for r in range(A.shape[0]) - ]) - if cpt % 10 == 1: - err = nx.sum(nx.abs(bold - b)) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images in log-domain. + """ + + A = list_to_array(A) + + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + + n_hists, width, height = A.shape + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + G = log_bar[None, :, :] - log_KU + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", + numItermax=10000, stopThr=1e-3, + verbose=False, log=False, warn=True, + **kwargs): + r"""Compute the debiased sinkhorn barycenter of distributions A + where A is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = arg\min_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn_debiased`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two + dimensions of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the debiased Sinkhorn scaling + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sinkhorn-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-15, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. + """ + + A = list_to_array(A) + n_hists, width, height = A.shape + + nx = get_backend(A) + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + bar = nx.ones((width, height), type_as=A) + bar /= width * height + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + c = nx.ones(A.shape[1:], type_as=A) + err = 1 + + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + + for _ in range(10): + c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 + + if ii % 10 == 9: + err = (V * KU).std(axis=0).sum() # log and verbose print if log: log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(ii, err)) + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii log['U'] = U - return b, log + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images in log-domain. + """ + + A = list_to_array(A) + n_hists, width, height = A.shape + nx = get_backend(A) + if nx.__name__ == "jax": + raise NotImplementedError("Log-domain functions are not yet implemented" + " for Jax. Use numpy or torch arrays instead.") + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_bar, c = nx.zeros((2, width, height), type_as=A) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + log_bar += c + for _ in range(10): + c = 0.5 * (c + log_bar - convol_img(c)) + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr and ii > 20: + break + G = log_bar[None, :, :] - log_KU + else: - return b + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False): + stopThr=1e-3, verbose=False, log=False, warn=True): r""" Compute the unmixing of an observation with a given dictionary using Wasserstein distance @@ -1679,16 +2411,21 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, where : - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)` + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + with M loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, + its expected shape is `(dim_a, n_atoms)` - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting - - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the + cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization - :math:`\\alpha` weight data fitting and regularization - The optimization problem is solved following the algorithm described in :ref:`[4] ` + The optimization problem is solved following the algorithm described + in :ref:`[4] ` Parameters @@ -1717,7 +2454,8 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Print information along iterations log : bool, optional record log if True - + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1731,8 +2469,10 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, References ---------- - .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. - + .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, + Supervised planetary unmixing with optimal transport, Whorkshop + on Hyperspectral Image and Signal Processing : + Evolution in Remote Sensing (WHISPERS), 2016. """ a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) @@ -1747,12 +2487,11 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, old = h0 err = 1 - cpt = 0 # log = {'niter':0, 'all_err':[]} if log: log = {'err': []} - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): K = projC(K, a) K0 = projC(K0, h0) new = nx.sum(K0, axis=1) @@ -1770,22 +2509,27 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, log['err'].append(err) if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - - cpt = cpt + 1 - + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Unmixing algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: - log['niter'] = cpt + log['niter'] = ii return nx.sum(K0, axis=1), log else: return nx.sum(K0, axis=1) def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] ` + stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as + proposed in :ref:`[27] ` The function solves the following optimization problem: @@ -1799,16 +2543,23 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, where : - :math:`\lambda_k` is the weight of `k`-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain + defined as in [p. 5, :ref:`27 `], its expected shape + is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source + domain and `C` is the number of classes - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in + [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` - The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + The problem consist in solving a Wasserstein barycenter problem to estimate + the proportions :math:`\mathbf{h}` in the target domain. The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. + with two sets of marginal constraints related to the unknown vector + :math:`\mathbf{h}` and uniform target distribution. Parameters ---------- @@ -1826,10 +2577,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Max number of iterations stopThr : float, optional Stop threshold on relative change in the barycenter (>0) - log : bool, optional - record log if True verbose : bool, optional (default=False) Controls the verbosity of the optimization algorithm + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -1844,9 +2597,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, ---------- .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia - "Optimal transport for multi-source domain adaptation under target shift", - International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. - + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. ''' Xs = list_to_array(*Xs) @@ -1901,11 +2653,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # uniform target distribution a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) - cpt = 0 # iterations count err = 1 old_bary = nx.ones((nbclasses,), type_as=Xs[0]) - while (err > stopThr and cpt < numItermax): + for ii in range(numItermax): bary = nx.zeros((nbclasses,), type_as=Xs[0]) @@ -1923,21 +2674,27 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, K[d] = projR(K[d], new) err = nx.norm(bary - old_bary) - cpt = cpt + 1 + old_bary = bary if log: log['err'].append(err) + if err < stopThr: + break if verbose: - if cpt % 200 == 0: + if ii % 200 == 0: print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) - + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") bary = bary / nx.sum(bary) if log: - log['niter'] = cpt + log['niter'] = ii log['M'] = M log['D1'] = D1 log['D2'] = D2 @@ -1949,7 +2706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, **kwargs): + log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem and return the OT matrix from empirical data @@ -1967,7 +2724,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -1988,7 +2746,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate full + cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. @@ -1996,6 +2756,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -2021,11 +2783,14 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. ''' X_s, X_t = list_to_array(X_s, X_t) @@ -2100,7 +2865,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if err <= stopThr: break - + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") if log: dict_log["u"] = f dict_log["v"] = g @@ -2111,15 +2880,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', else: M = dist(X_s, X_t, metric=metric) if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=True, **kwargs) return pi, log else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs) + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=False, **kwargs) return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - isLazy=False, batchSize=100, verbose=False, log=False, **kwargs): +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, isLazy=False, + batchSize=100, verbose=False, log=False, warn=True, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -2138,7 +2910,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num where : - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2159,7 +2932,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num stopThr : float, optional Stop threshold on error (>0) isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function. + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate + full cost matrix and return outputs of sinkhorn function. batchSize: int or tuple of 2 int, optional Size of the batches used to compute the sinkhorn update without memory overhead. When a tuple is provided it sets the size of the left/right batches. @@ -2167,6 +2942,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns @@ -2192,11 +2969,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num References ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling + Algorithms for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. ''' X_s, X_t = list_to_array(X_s, X_t) @@ -2211,11 +2994,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num if isLazy: if log: - f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, log=log, + warn=warn) else: - f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log) + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, + verbose=verbose, log=log, + warn=warn) bs = batchSize if isinstance(batchSize, int) else batchSize[0] range_s = range(0, ns, bs) @@ -2241,17 +3032,21 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num M = nx.from_numpy(M, type_as=a) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, - **kwargs) + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, **kwargs): +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, warn=True, + **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -2288,8 +3083,11 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli \gamma_b &\geq 0 where : - - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) + is the (`n_samples_a`, `n_samples_b`) metric cost matrix + (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) @@ -2313,6 +3111,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Print information along iterations log : bool, optional record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. Returns ------- @@ -2334,17 +3134,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli References ---------- - .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 + .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative + Models with Sinkhorn Divergences, Proceedings of the Twenty-First + International Conference on Artficial Intelligence and Statistics, + (AISTATS) 21, 2018 ''' if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, warn=warn, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, warn=warn, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, - stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, + log=log, warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -2359,25 +3168,33 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli return max(0, sinkhorn_div), log else: - sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, + warn=warn, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, + warn=warn, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, - verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, + warn=warn, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) -def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True, - maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False): +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, + restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, + verbose=False, log=False): r""" Screening Sinkhorn Algorithm for Regularized Optimal Transport - The function solves an approximate dual of Sinkhorn divergence :ref:`[2] ` which is written as the following optimization problem: + The function solves an approximate dual of Sinkhorn divergence :ref:`[2] + ` which is written as the following optimization problem: .. math:: @@ -2395,56 +3212,49 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} - The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] ` + The parameters `kappa` and `epsilon` are determined w.r.t the couple number + budget of points (`ns_budget`, `nt_budget`), see Equation (5) + in :ref:`[26] ` Parameters ---------- - a : array-like, shape=(ns,) + a: array-like, shape=(ns,) samples weights in the source domain - - b : array-like, shape=(nt,) + b: array-like, shape=(nt,) samples weights in the target domain - - M : array-like, shape=(ns, nt) + M: array-like, shape=(ns, nt) Cost matrix - - reg : `float` + reg: `float` Level of the entropy regularisation - - ns_budget : `int`, default=None + ns_budget: `int`, default=None Number budget of points to be kept in the source domain. If it is None then 50% of the source sample points will be kept - - nt_budget : `int`, default=None + nt_budget: `int`, default=None Number budget of points to be kept in the target domain. If it is None then 50% of the target sample points will be kept - - uniform : `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` - + uniform: `bool`, default=False + If `True`, the source and target distribution are supposed to be uniform, + i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` restricted : `bool`, default=True If `True`, a warm-start initialization for the L-BFGS-B solver using a restricted Sinkhorn algorithm with at most 5 iterations - - maxiter : `int`, default=10000 + maxiter: `int`, default=10000 Maximum number of iterations in LBFGS solver - - maxfun : `int`, default=10000 + maxfun: `int`, default=10000 Maximum number of function evaluations in LBFGS solver - - pgtol : `float`, default=1e-09 + pgtol: `float`, default=1e-09 Final objective function accuracy in LBFGS solver - - verbose : `bool`, default=False - If `True`, display informations about the cardinals of the active sets and the parameters kappa - and epsilon - + verbose: `bool`, default=False + If `True`, display informations about the cardinals of the active sets + and the parameters kappa and epsilon Dependency ---------- - To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/) - in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears: + To gain more efficiency, screenkhorn needs to call the "Bottleneck" + package (https://pypi.org/project/Bottleneck/) + in the screening pre-processing step. If Bottleneck isn't installed, + the following error message appears: "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" @@ -2461,9 +3271,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res References ----------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 - .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). + Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 """ # check if bottleneck module exists @@ -2471,14 +3283,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res import bottleneck except ImportError: warnings.warn( - "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") + "Bottleneck module is not installed. Install it from" + " https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.") + raise TypeError("JAX arrays have been received but screenkhorn is not " + "compatible with JAX.") ns, nt = M.shape @@ -2582,7 +3396,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res if verbose: print("epsilon = %s\n" % epsilon) print("kappa = %s\n" % kappa) - print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel))) + print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' + % (sum(Isel), sum(Jsel))) # Ic, Jc: complementary of the active sets I and J Ic = ~Isel @@ -2638,13 +3453,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa - cpt = 1 - while cpt < 5: # 5 iterations + for _ in range(5): # 5 iterations K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v v0 = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, v0) + cst_u u0 = (kappa * a_I) / KIJ_u - cpt += 1 u0 = projection(u0, epsilon / kappa) v0 = projection(v0, epsilon * kappa) @@ -2655,15 +3468,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res def restricted_sinkhorn(usc, vsc, max_iter=5): """ - Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26]) + Restricted Sinkhorn Algorithm as a warm-start initialized pointfor L-BFGS-B) """ - cpt = 1 - while cpt < max_iter: + for _ in range(max_iter): K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v vsc = b_J / (kappa * K_IJ_v) KIJ_u = nx.dot(K_IJ, vsc) + cst_u usc = (kappa * a_I) / KIJ_u - cpt += 1 usc = projection(usc, epsilon / kappa) vsc = projection(vsc, epsilon * kappa) diff --git a/test/test_bregman.py b/test/test_bregman.py index 6923d31..edfe9c3 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -6,6 +6,8 @@ # # License: MIT License +from itertools import product + import numpy as np import pytest @@ -13,7 +15,8 @@ import ot from ot.backend import torch -def test_sinkhorn(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn(verbose, warn): # test sinkhorn n = 100 rng = np.random.RandomState(0) @@ -23,7 +26,7 @@ def test_sinkhorn(): M = ot.dist(x, x) - G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) + G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn) # check constraints np.testing.assert_allclose( @@ -31,8 +34,92 @@ def test_sinkhorn(): np.testing.assert_allclose( u, G.sum(0), atol=1e-05) # cf convergence sinkhorn + with pytest.warns(UserWarning): + ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_epsilon_scaling", + "greenkhorn", + "sinkhorn_log"]) +def test_convergence_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + A = np.asarray([a1, a2]).T + M = ot.utils.dist0(n) + + with pytest.warns(UserWarning): + ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1) + + if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]: + with pytest.warns(UserWarning): + ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) + with pytest.warns(UserWarning): + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + + +def test_not_impemented_method(): + # test sinkhorn + w = 10 + n = w ** 2 + rng = np.random.RandomState(42) + A_img = rng.rand(2, w, w) + A_flat = A_img.reshape(n, 2) + a1, a2 = A_flat.T + M_flat = ot.utils.dist0(n) + not_implemented = "new_method" + reg = 0.01 + with pytest.raises(ValueError): + ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.barycenter(A_flat, M_flat, reg, method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.barycenter_debiased(A_flat, M_flat, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d(A_img, reg, + method=not_implemented) + with pytest.raises(ValueError): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, + method=not_implemented) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) +def test_nan_warning(method): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + + M = ot.utils.dist0(n) + reg = 0 + with pytest.warns(UserWarning): + # warn set to False to avoid catching a convergence warning instead + ot.sinkhorn(a1, a2, M, reg, method=method, warn=False) + + +def test_sinkhorn_stabilization(): + # test sinkhorn + n = 100 + a1 = ot.datasets.make_1D_gauss(n, m=30, s=10) + a2 = ot.datasets.make_1D_gauss(n, m=40, s=10) + M = ot.utils.dist0(n) + reg = 1e-5 + loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log") + loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized") + np.testing.assert_allclose( + loss1, loss2, atol=1e-06) # cf convergence sinkhorn + -def test_sinkhorn_multi_b(): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", + "sinkhorn_log"], + [True, False], [True, False])) +def test_sinkhorn_multi_b(method, verbose, warn): # test sinkhorn n = 10 rng = np.random.RandomState(0) @@ -45,12 +132,14 @@ def test_sinkhorn_multi_b(): M = ot.dist(x, x) - loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True) + loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10, + log=True) - loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)] + loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10, + verbose=verbose, warn=warn) for k in range(3)] # check constraints np.testing.assert_allclose( - loss0, loss, atol=1e-06) # cf convergence sinkhorn + loss0, loss, atol=1e-4) # cf convergence sinkhorn def test_sinkhorn_backends(nx): @@ -67,9 +156,9 @@ def test_sinkhorn_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn(ab, ab, Mb, 1) + Gb = ot.sinkhorn(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -88,9 +177,9 @@ def test_sinkhorn2_backends(nx): G = ot.sinkhorn(a, a, M, 1) ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) - Gb = ot.sinkhorn2(ab, ab, Mb, 1) + Gb = ot.sinkhorn2(ab, ab, M_nx, 1) np.allclose(G, nx.to_numpy(Gb)) @@ -131,6 +220,12 @@ def test_sinkhorn_empty(): M = ot.dist(x, x) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log", + verbose=True, log=True) + # check constraints + np.testing.assert_allclose(u, G.sum(1), atol=1e-05) + np.testing.assert_allclose(u, G.sum(0), atol=1e-05) + G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True) # check constraints np.testing.assert_allclose(u, G.sum(1), atol=1e-05) @@ -165,15 +260,15 @@ def test_sinkhorn_variants(nx): M = ot.dist(x, x) ub = nx.from_numpy(u) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) Ges = nx.to_numpy(ot.sinkhorn( - ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) - G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10)) + ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)) + G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -199,12 +294,12 @@ def test_sinkhorn_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -228,12 +323,12 @@ def test_sinkhorn2_variants_multi_b(nx): ub = nx.from_numpy(u) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10) - Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10)) - G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10)) - Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10)) + Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10)) + G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10)) + Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10)) # check values np.testing.assert_allclose(G, G0, atol=1e-05) @@ -255,7 +350,7 @@ def test_sinkhorn_variants_log(): Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) Ges, loges = ot.sinkhorn( - u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True) + u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True,) G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True) # check values @@ -265,7 +360,8 @@ def test_sinkhorn_variants_log(): np.testing.assert_allclose(G0, G_green, atol=1e-5) -def test_sinkhorn_variants_log_multib(): +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_sinkhorn_variants_log_multib(verbose, warn): # test sinkhorn n = 50 rng = np.random.RandomState(0) @@ -278,16 +374,20 @@ def test_sinkhorn_variants_log_multib(): M = ot.dist(x, x) G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True) - Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True) - Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True) + Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) + Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True, + verbose=verbose, warn=warn) # check values np.testing.assert_allclose(G0, Gs, atol=1e-05) np.testing.assert_allclose(G0, Gl, atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"]) -def test_barycenter(nx, method): +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter(nx, method, verbose, warn): n_bins = 100 # nb bins # Gaussian distributions @@ -304,20 +404,98 @@ def test_barycenter(nx, method): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) - weightsb = nx.from_numpy(weights) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) + reg = 1e-2 + + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method) + else: + # wasserstein + bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass)) + np.testing.assert_allclose(bary_wass, bary_wass_np) + + ot.bregman.barycenter(A_nx, M_nx, reg, log=True) + + +@pytest.mark.parametrize("method, verbose, warn", + product(["sinkhorn", "sinkhorn_log"], + [True, False], [True, False])) +def test_barycenter_debiased(nx, method, verbose, warn): + n_bins = 100 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) + weights_nx = nx.from_numpy(weights) # wasserstein reg = 1e-2 - bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True) - bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True) - bary_wass = nx.to_numpy(bary_wass) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method) + else: + bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, + verbose=verbose, warn=warn) + bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True) + bary_wass = nx.to_numpy(bary_wass) + + np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5) + + ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) - ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_convergence_warning_barycenters(method): + w = 10 + n_bins = w ** 2 # nb bins + + # Gaussian distributions + a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std + a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10) + + # creating matrix A containing all distributions + A = np.vstack((a1, a2)).T + A_img = A.reshape(2, w, w) + A_img /= A_img.sum((1, 2))[:, None, None] + + # loss matrix + normalization + M = ot.utils.dist0(n_bins) + M /= M.max() + + alpha = 0.5 # 0<=alpha<=1 + weights = np.array([1 - alpha, alpha]) + reg = 0.1 + with pytest.warns(UserWarning): + ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d(A_img, reg, weights, + method=method, numItermax=1) + with pytest.warns(UserWarning): + ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights, + method=method, numItermax=1) def test_barycenter_stabilization(nx): @@ -337,31 +515,64 @@ def test_barycenter_stabilization(nx): alpha = 0.5 # 0<=alpha<=1 weights = np.array([1 - alpha, alpha]) - Ab = nx.from_numpy(A) - Mb = nx.from_numpy(M) + A_nx = nx.from_numpy(A) + M_nx = nx.from_numpy(M) weights_b = nx.from_numpy(weights) # wasserstein reg = 1e-2 bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True) bar_stable = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn_stabilized", + A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized", stopThr=1e-8, verbose=True )) bar = nx.to_numpy(ot.bregman.barycenter( - Ab, Mb, reg, weights_b, method="sinkhorn", + A_nx, M_nx, reg, weights_b, method="sinkhorn", stopThr=1e-8, verbose=True )) np.testing.assert_allclose(bar, bar_stable) np.testing.assert_allclose(bar, bar_np) -def test_wasserstein_bary_2d(nx): - size = 100 # size of a square image - a1 = np.random.randn(size, size) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) + a1 += a1.min() + a1 = a1 / np.sum(a1) + a2 = np.random.rand(size, size) + a2 += a2.min() + a2 = a2 / np.sum(a2) + # creating matrix A containing all distributions + A = np.zeros((2, size, size)) + A[0, :, :] = a1 + A[1, :, :] = a2 + + A_nx = nx.from_numpy(A) + + # wasserstein + reg = 1e-2 + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)) + + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) + + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + + +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"]) +def test_wasserstein_bary_2d_debiased(nx, method): + size = 20 # size of a square image + a1 = np.random.rand(size, size) a1 += a1.min() a1 = a1 / np.sum(a1) - a2 = np.random.randn(size, size) + a2 = np.random.rand(size, size) a2 += a2.min() a2 = a2 / np.sum(a2) # creating matrix A containing all distributions @@ -369,18 +580,22 @@ def test_wasserstein_bary_2d(nx): A[0, :, :] = a1 A[1, :, :] = a2 - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) # wasserstein reg = 1e-2 - bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg) - bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg)) + if nx.__name__ == "jax" and method == "sinkhorn_log": + with pytest.raises(NotImplementedError): + ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method) + else: + bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method) + bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)) - np.testing.assert_allclose(1, np.sum(bary_wass)) - np.testing.assert_allclose(bary_wass, bary_wass_np) + np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3) + np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3) - # help in checking if log and verbose do not bug the function - ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) + # help in checking if log and verbose do not bug the function + ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True) def test_unmix(nx): @@ -405,20 +620,20 @@ def test_unmix(nx): ab = nx.from_numpy(a) Db = nx.from_numpy(D) - Mb = nx.from_numpy(M) + M_nx = nx.from_numpy(M) M0b = nx.from_numpy(M0) h0b = nx.from_numpy(h0) # wasserstein reg = 1e-3 um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01) - um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01)) + um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01)) np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03) np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03) np.testing.assert_allclose(um, um_np) - ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, + ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01, log=True, verbose=True) @@ -437,22 +652,22 @@ def test_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1)) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True) G_log = nx.to_numpy(G_log) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean')) sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1)) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -486,18 +701,18 @@ def test_lazy_empirical_sinkhorn(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_mb = nx.from_numpy(M_m, type_as=ab) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_sqe = np.exp(f[:, None] + g[None, :] - M / 1) - sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1)) + sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1) @@ -507,7 +722,7 @@ def test_lazy_empirical_sinkhorn(nx): loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) - loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1)) + loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) # check constraints np.testing.assert_allclose( @@ -541,13 +756,13 @@ def test_empirical_sinkhorn_divergence(nx): bb = nx.from_numpy(b) X_sb = nx.from_numpy(X_s) X_tb = nx.from_numpy(X_t) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) M_sb = nx.from_numpy(M_s, type_as=ab) M_tb = nx.from_numpy(M_t, type_as=ab) emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb)) sinkhorn_div = nx.to_numpy( - ot.sinkhorn2(ab, bb, Mb, 1) + ot.sinkhorn2(ab, bb, M_nx, 1) - 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1) - 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1) ) @@ -580,14 +795,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True) - G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon, + G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon, method="sinkhorn_stabilized", log=True) G = nx.to_numpy(G) - G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon, + G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon, method="sinkhorn", log=True) G2 = nx.to_numpy(G2) @@ -642,14 +857,14 @@ def test_screenkhorn(nx): ab = nx.from_numpy(a) bb = nx.from_numpy(b) - Mb = nx.from_numpy(M, type_as=ab) + M_nx = nx.from_numpy(M, type_as=ab) # np sinkhorn G_sink_np = ot.sinkhorn(a, b, M, 1e-03) # sinkhorn - G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03)) + G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03)) # screenkhorn - G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True)) + G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True)) # check marginals np.testing.assert_allclose(G_sink_np, G_sink) np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02) @@ -659,10 +874,10 @@ def test_screenkhorn(nx): def test_convolutional_barycenter_non_square(nx): # test for image with height not equal width A = np.ones((2, 2, 3)) / (2 * 3) - Ab = nx.from_numpy(A) + A_nx = nx.from_numpy(A) b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03) - b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03)) + b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03)) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02) -- cgit v1.2.3 From 2fe69eb130827560ada704bc25998397c4357821 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Thu, 4 Nov 2021 11:00:09 +0100 Subject: [MRG] Make gromov loss differentiable wrt matrices and weights (#302) * grmov differentable * new stuff * test gromov gradients * fgwdifferentiable * fgw tested * correc name test * add awesome example with gromov optimizatrion * pep8+ typos * damn pep8 * thunbnail * remove prints --- README.md | 9 +- examples/backends/plot_optim_gromov_pytorch.py | 260 +++++++++++++++++++++++++ ot/__init__.py | 2 + ot/gromov.py | 141 +++++++++++--- ot/optim.py | 3 +- test/test_gromov.py | 76 ++++++++ 6 files changed, 460 insertions(+), 31 deletions(-) create mode 100644 examples/backends/plot_optim_gromov_pytorch.py (limited to 'examples') diff --git a/README.md b/README.md index ff32c53..08db003 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ POT provides the following generic OT solvers (links to examples): * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale). -* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]) +* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24] * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19]) * [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33] @@ -295,5 +295,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. -[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International -Conference on Machine Learning, PMLR 119:4692-4701, 2020 \ No newline at end of file +[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International +Conference on Machine Learning, PMLR 119:4692-4701, 2020 + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph +Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. \ No newline at end of file diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py new file mode 100644 index 0000000..465f612 --- /dev/null +++ b/examples/backends/plot_optim_gromov_pytorch.py @@ -0,0 +1,260 @@ +r""" +================================= +Optimizing the Gromov-Wasserstein distance with PyTorch +================================= + +In this exemple we use the pytorch backend to optimize the Gromov-Wasserstein +(GW) loss between two graphs expressed as empirical distribution. + +In the first example we optimize the weights on the node of a simple template +graph so that it minimizes the GW with a given Stochastic Block Model graph. +We can see that this actually recovers the proportion of classes in the SBM +and allows for an accurate clustering of the nodes using the GW optimal plan. + +In a second example we optimize simultaneously the weights and the sructure of +the template graph which allows us to perform graph compression and to recover +other properties of the SBM. + +The backend actually uses the gradients expressed in [38] to optimize the +weights. + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph +Dictionary Learning, International Conference on Machine Learning (ICML), 2021. + +""" +# Author: Rémi Flamary +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +from sklearn.manifold import MDS +import numpy as np +import matplotlib.pylab as pl +import torch + +import ot +from ot.gromov import gromov_wasserstein2 + +# %% +# Graph generation +# --------------- + +rng = np.random.RandomState(42) + + +def get_sbm(n, nc, ratio, P): + nbpc = np.round(n * ratio).astype(int) + n = np.sum(nbpc) + C = np.zeros((n, n)) + for c1 in range(nc): + for c2 in range(c1 + 1): + if c1 == c2: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for j in range(np.sum(nbpc[:c2]), i): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + else: + for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])): + for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])): + if rng.rand() <= P[c1, c2]: + C[i, j] = 1 + + return C + C.T + + +n = 100 +nc = 3 +ratio = np.array([.5, .3, .2]) +P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3))) +C1 = get_sbm(n, nc, ratio, P) + +# get 2d position for nodes +x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1) + + +def plot_graph(x, C, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(x1, C1, color='C0') +pl.title("SBM Graph") +pl.axis("off") +pl.subplot(1, 2, 2) +pl.imshow(C1, interpolation='nearest') +pl.title("Adjacency matrix") +pl.axis("off") + + +# %% +# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1 +# ------------------------------------------------ +# The adajacency matrix C1 is block diagonal with 3 blocks. We want to +# optimize the weights of a simple template C0=eye(3) and see if we can +# recover the proportion of classes from the SBM (up to a permutation). + +C0 = np.eye(3) + + +def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2): + """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + + # use pyTorch for our data + C1_torch = torch.tensor(C1) + C2_torch = torch.tensor(C2) + + a0 = rng.rand(C1.shape[0]) # random_init + a0 /= a0.sum() # on simplex + a1_torch = torch.tensor(a0).requires_grad_(True) + a2_torch = torch.tensor(a2) + + loss_iter = [] + + for i in range(nb_iter_max): + + loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + #print("{:03d} | {}".format(i, loss_iter[-1])) + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a1_torch.grad + a1_torch -= grad * lr # step + a1_torch.grad.zero_() + a1_torch.data = ot.utils.proj_simplex(a1_torch) + + a1 = a1_torch.clone().detach().cpu().numpy() + + return a1, loss_iter + + +a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2) + +pl.figure(2) +pl.plot(loss_iter0) +pl.title("Loss along iterations") + +print("Estimated weights : ", a0_est) +print("True proportions : ", ratio) + + +# %% +# It is clear that the optimization has converged and that we recover the +# ratio of the different classes in the SBM graph up to a permutation. + + +# %% +# Community clustering with uniform and estimated weights +# -------------------------------------------- +# The GW OT plan can be used to perform a clustering of the nodes of a graph +# when computing the GW with a simple template like C0 by labeling nodes in +# the original graph using by the index of the noe in the template receiving +# the most mass. +# +# We show here the result of such a clustering when using uniform weights on +# the template C0 and when using the optimal weights previously estimated. + + +T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3)) +label_unif = T_unif.argmax(1) + +T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est) +label_est = T_est.argmax(1) + +pl.figure(3, (10, 5)) +pl.clf() +pl.subplot(1, 2, 1) +plot_graph(x1, C1, color=label_unif) +pl.title("Graph clustering unif. weights") +pl.axis("off") +pl.subplot(1, 2, 2) +plot_graph(x1, C1, color=label_est) +pl.title("Graph clustering est. weights") +pl.axis("off") + + +# %% +# Graph compression with GW +# ------------------------- + +# Now we optimize both the weights and structure of a small graph that +# minimize the GW distance wrt our data graph. This can be seen as graph +# compression but can also recover important properties of an SBM such +# as its class proportion but also its matrix of probability of links between +# classes + + +def graph_compession_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2): + """ solve min_a GW(C1,C2,a, a2) by gradient descent""" + + # use pyTorch for our data + + C2_torch = torch.tensor(C2) + a2_torch = torch.tensor(a2) + + a0 = rng.rand(nb_nodes) # random_init + a0 /= a0.sum() # on simplex + a1_torch = torch.tensor(a0).requires_grad_(True) + C0 = np.eye(nb_nodes) + C1_torch = torch.tensor(C0).requires_grad_(True) + + loss_iter = [] + + for i in range(nb_iter_max): + + loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch) + + loss_iter.append(loss.clone().detach().cpu().numpy()) + loss.backward() + + #print("{:03d} | {}".format(i, loss_iter[-1])) + + # performs a step of projected gradient descent + with torch.no_grad(): + grad = a1_torch.grad + a1_torch -= grad * lr # step + a1_torch.grad.zero_() + a1_torch.data = ot.utils.proj_simplex(a1_torch) + + grad = C1_torch.grad + C1_torch -= grad * lr # step + C1_torch.grad.zero_() + C1_torch.data = torch.clamp(C1_torch, 0, 1) + + a1 = a1_torch.clone().detach().cpu().numpy() + C1 = C1_torch.clone().detach().cpu().numpy() + + return a1, C1, loss_iter + + +nb_nodes = 3 +a0_est2, C0_est2, loss_iter2 = graph_compession_gw(nb_nodes, C1, ot.unif(n), + nb_iter_max=100, lr=5e-2) + +pl.figure(4) +pl.plot(loss_iter2) +pl.title("Loss along iterations") + + +print("Estimated weights : ", a0_est2) +print("True proportions : ", ratio) + +pl.figure(6, (10, 3.5)) +pl.clf() +pl.subplot(1, 2, 1) +pl.imshow(P, vmin=0, vmax=1) +pl.title('True SBM P matrix') +pl.subplot(1, 2, 2) +pl.imshow(C0_est2, vmin=0, vmax=1) +pl.title('Estimated C0 matrix') +pl.colorbar() diff --git a/ot/__init__.py b/ot/__init__.py index f20332c..4292b41 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -43,6 +43,8 @@ from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .gromov import (gromov_wasserstein, gromov_wasserstein2, + gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) # utils functions from .utils import dist, unif, tic, toc, toq diff --git a/ot/gromov.py b/ot/gromov.py index 465693d..ea667e4 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): - """Return the Loss for Gromov-Wasserstein + r"""Return the Loss for Gromov-Wasserstein The loss is computed as described in Proposition 1 Eq. (6) in :ref:`[12] ` @@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T): def gwggrad(constC, hC1, hC2, T): - """Return the gradient for Gromov-Wasserstein + r"""Return the gradient for Gromov-Wasserstein The gradient is computed as described in Proposition 2 in :ref:`[12] ` @@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T): def update_square_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs): def update_kl_loss(p, lambdas, T, Cs): - """ + r""" Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration @@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -398,13 +406,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs if log: res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log['gw_dist'] = gwloss(constC, hC1, hC2, res) - return res, log + log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10) + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log else: - return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10) -def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): +def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})` @@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg - :math:`\mathbf{C_2}`: Metric cost matrix in the target space - :math:`\mathbf{p}`: distribution in the source space - :math:`\mathbf{q}`: distribution in the target space - - `L`: loss function to account for the misfit between the similarity matrices + - `L`: loss function to account for the misfit between the similarity + matrices + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. + """ p, q = list_to_array(p, q) + p0, q0, C10, C20 = p, q, C1, C2 + nx = get_backend(p0, q0, C10, C20) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -475,13 +501,28 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg def df(G): return gwggrad(constC, hC1, hC2, G) - res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) - log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res) - log_gw['T'] = res + + T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + + T0 = nx.from_numpy(T, type_as=C10) + + log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10) + log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10) + log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10) + log_gw['T'] = T0 + + gw = log_gw['gw_dist'] + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + gw = nx.set_gradients(gw, (p0, q0, C10, C20), + (log_gw['u'], log_gw['v'], gC1, gC2)) + if log: - return log_gw['gw_dist'], log_gw + return gw, log_gw else: - return log_gw['gw_dist'] + return gw def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -560,10 +610,16 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) - log['fgw_dist'] = log['loss'][::-1][0] - return res, log + + fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10) + + log['fgw_dist'] = fgw_dist + log['u'] = nx.from_numpy(log['u'], type_as=C10) + log['v'] = nx.from_numpy(log['v'], type_as=C10) + return nx.from_numpy(res, type_as=C10), log + else: - return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs) + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10) def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): @@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1) - `L` is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] ` + The algorithm used for solving the problem is conditional gradient as + discussed in :ref:`[24] ` + + Note that when using backends, this loss function is differentiable wrt the + marices and weights for quadratic loss using the gradients from [38]_. Parameters ---------- @@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -640,13 +713,27 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + + fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) + + T0 = nx.from_numpy(T, type_as=C10) + + log_fgw['fgw_dist'] = fgw_dist + log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) + log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) + log_fgw['T'] = T0 + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), + (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + if log: - log['fgw_dist'] = log['loss'][::-1][0] - log['T'] = res - return log['fgw_dist'], log + return fgw_dist, log_fgw else: - return log['fgw_dist'] + return fgw_dist def GW_distance_estimation(C1, C2, p, q, loss_fun, T, @@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None): - """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- @@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_structure_matrix(p, lambdas, T, Cs): - """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration @@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. See "Solving the barycenter problem with Block Coordinate Descent (BCD)" diff --git a/ot/optim.py b/ot/optim.py index cc286b6..bd8ca26 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -267,7 +267,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G @@ -297,6 +297,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G diff --git a/test/test_gromov.py b/test/test_gromov.py index 509c54d..bcbcc3a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,7 @@ import numpy as np import ot from ot.backend import NumpyBackend +from ot.backend import torch import pytest @@ -74,6 +75,42 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + + val = ot.gromov_wasserstein2(C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_entropic_gromov(nx): n_samples = 50 # nb samples @@ -389,6 +426,45 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_fgw2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + def test_fgw_barycenter(nx): np.random.seed(42) -- cgit v1.2.3 From cec41d3817067a2eb3031092735347efe4184237 Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 5 Nov 2021 17:13:14 +0100 Subject: [MRG] Release 0.8 (#289) * working on release * test circleci * try again * cleanup circle ci run * add all PR and releant Issues * update doc * thanks idris * update version + add pyproject.toml * test pyproject.toml * revert tests * build wheels * use windows-latest for tests * add tests python 3.10 * build all whels * all versions * build all wheels * build all wheels * cleanup pep8 and minimal acions * forst shot text release * bettr text * stuff * release text updated * update manifest to allow build from source * update doc again * update release --- .github/workflows/build_tests.yml | 26 ++-- .github/workflows/build_wheels.yml | 13 +- .github/workflows/build_wheels_weekly.yml | 5 +- MANIFEST.in | 2 + RELEASES.md | 192 ++++++++++++++++++++++--- docs/source/readme.rst | 28 +++- docs/source/releases.rst | 134 ++++++++++++++++- examples/backends/plot_optim_gromov_pytorch.py | 2 +- ot/__init__.py | 8 +- ot/gpu/__init__.py | 10 +- ot/helpers/__init__.py | 3 + 11 files changed, 365 insertions(+), 58 deletions(-) create mode 100644 ot/helpers/__init__.py (limited to 'examples') diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index 91631b4..ee5a435 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -22,7 +22,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [ 3.6, 3.7, 3.8, 3.9] + python-version: [ "3.6", "3.7", "3.8", "3.9"] steps: - uses: actions/checkout@v1 @@ -48,17 +48,12 @@ jobs: pep8: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'no pep8')" - strategy: - max-parallel: 4 - matrix: - python-version: [3.8] - steps: - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python uses: actions/setup-python@v1 with: - python-version: ${{ matrix.python-version }} + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -74,17 +69,12 @@ jobs: runs-on: ubuntu-latest if: "!contains(github.event.head_commit.message, 'no ci')" - strategy: - max-parallel: 4 - matrix: - python-version: [3.8] - steps: - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python uses: actions/setup-python@v1 with: - python-version: ${{ matrix.python-version }} + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -103,7 +93,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9"] steps: - uses: actions/checkout@v1 @@ -125,12 +115,12 @@ jobs: windows: - runs-on: windows-2019 + runs-on: windows-latest if: "!contains(github.event.head_commit.message, 'no ci')" strategy: max-parallel: 4 matrix: - python-version: [3.7, 3.8, 3.9] + python-version: ["3.7", "3.8", "3.9"] steps: - uses: actions/checkout@v1 diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml index 53246ce..a935a5e 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/build_wheels.yml @@ -3,6 +3,7 @@ name: Build wheels on: workflow_dispatch: release: + pull_request: push: branches: - "*" @@ -31,7 +32,7 @@ jobs: - name: Install cibuildwheel run: | - python -m pip install cibuildwheel==1.10.0 + python -m pip install cibuildwheel==2.2.2 - name: Build wheels env: @@ -69,12 +70,7 @@ jobs: - name: Install cibuildwheel run: | - python -m pip install cibuildwheel==1.10.0 - - - name: Install Visual C++ for Python 2.7 - if: startsWith(matrix.os, 'windows') - run: | - choco install vcpython27 -f -y + python -m pip install cibuildwheel==2.2.2 - name: Set up QEMU if: runner.os == 'Linux' @@ -84,9 +80,10 @@ jobs: - name: Build wheels env: - CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp*" # remove pypy on mac and win (wrong version) + CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version) CIBW_BEFORE_BUILD: "pip install numpy cython" CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU + CIBW_ARCHS_MACOS: x86_64 universal2 arm64 run: | python -m cibuildwheel --output-dir wheelhouse diff --git a/.github/workflows/build_wheels_weekly.yml b/.github/workflows/build_wheels_weekly.yml index 32b697f..2964844 100644 --- a/.github/workflows/build_wheels_weekly.yml +++ b/.github/workflows/build_wheels_weekly.yml @@ -31,7 +31,7 @@ jobs: - name: Install cibuildwheel run: | - python -m pip install cibuildwheel==1.10.0 + python -m pip install cibuildwheel==2.2.2 - name: Set up QEMU if: runner.os == 'Linux' @@ -41,9 +41,10 @@ jobs: - name: Build wheels env: - CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp*" # remove pypy on mac and win (wrong version) + CIBW_SKIP: "pp*-win* pp*-macosx* cp2* pp* cp*musl*" # remove pypy on mac and win (wrong version) CIBW_BEFORE_BUILD: "pip install numpy cython" CIBW_ARCHS_LINUX: auto aarch64 # force aarch64 with QEMU + CIBW_ARCHS_MACOS: x86_64 universal2 arm64 run: | python -m cibuildwheel --output-dir wheelhouse diff --git a/MANIFEST.in b/MANIFEST.in index df4e139..da67c77 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -6,4 +6,6 @@ include ot/lp/EMD.h include ot/lp/EMD_wrapper.cpp include ot/lp/emd_wrap.pyx include ot/lp/full_bipartitegraph.h +include ot/lp/full_bipartitegraph_omp.h include ot/lp/network_simplex_simple.h +include ot/lp/network_simplex_simple_omp.h diff --git a/RELEASES.md b/RELEASES.md index adb7fc1..6eb1502 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,19 +1,167 @@ # Releases -## 0.7.0 -*May 2020* -This is the new stable release for POT. We made a lot of changes in the documentation and added several new features such as Partial OT, Unbalanced and Multi Sources OT Domain Adaptation and several bug fixes. One important change is that we have created the GitHub organization [PythonOT](https://github.com/PythonOT) that now owns the main POT repository [https://github.com/PythonOT/POT](https://github.com/PythonOT/POT) and the repository for the new documentation is now hosted at [https://PythonOT.github.io/](https://PythonOT.github.io/). +## 0.8.0 +*November 2021* + +This new stable release introduces several important features. + +First we now have +an OpenMP compatible exact ot solver in `ot.emd`. The OpenMP version is used +when the parameter `numThreads` is greater than one and can lead to nice +speedups on multi-core machines. + +Second we have introduced a backend mechanism that allows to use standard POT +function seamlessly on Numpy, Pytorch and Jax arrays. Other backends are coming +but right now POT can be used seamlessly for training neural networks in +Pytorch. Notably we propose the first differentiable computation of the exact OT +loss with `ot.emd2` (can be differentiated w.r.t. both cost matrix and sample +weights), but also for the classical Sinkhorn loss with `ot.sinkhorn2`, the +Wasserstein distance in 1D with `ot.wasserstein_1d`, sliced Wasserstein with +`ot.sliced_wasserstein_distance` and Gromov-Wasserstein with `ot.gromov_wasserstein2`. Examples of how +this new feature can be used are now available in the documentation where the +Pytorch backend is used to estimate a [minimal Wasserstein +estimator](https://PythonOT.github.io/auto_examples/backends/plot_unmix_optim_torch.html), +a [Generative Network +(GAN)](https://PythonOT.github.io/auto_examples/backends/plot_wass2_gan_torch.html), +for a [sliced Wasserstein gradient +flow](https://PythonOT.github.io/auto_examples/backends/plot_sliced_wass_grad_flow_pytorch.html) +and [optimizing the Gromov-Wassersein distance](https://PythonOT.github.io/auto_examples/backends/plot_optim_gromov_pytorch.html). Note that the Jax backend is still in early development and quite +slow at the moment, we strongly recommend for Jax users to use the [OTT +toolbox](https://github.com/google-research/ott) when possible. + As a result of this new feature, + the old `ot.gpu` submodule is now deprecated since GPU +implementations can be done using GPU arrays on the torch backends. + +Other novel features include implementation for [Sampled Gromov Wasserstein and +Pointwise Gromov +Wasserstein](https://PythonOT.github.io/auto_examples/gromov/plot_gromov.html#compute-gw-with-a-scalable-stochastic-method-with-any-loss-function), +Sinkhorn in log space with `method='sinkhorn_log'`, [Projection Robust +Wasserstein](https://PythonOT.github.io/gen_modules/ot.dr.html?highlight=robust#ot.dr.projection_robust_wasserstein), +ans [deviased Sinkorn barycenters](https://PythonOT.github.ioauto_examples/barycenters/plot_debiased_barycenter.html). + +This release will also simplify the installation process. We have now a +`pyproject.toml` that defines the build dependency and POT should now build even +when cython is not installed yet. Also we now provide pe-compiled wheels for +linux `aarch64` that is used on Raspberry PI and android phones and for MacOS on +ARM processors. + + +Finally POT was accepted for publication in the Journal of Machine Learning +Research (JMLR) open source software track and we ask the POT users to cite [this +paper](https://www.jmlr.org/papers/v22/20-451.html) from now on. The documentation has been improved in particular by adding a +"Why OT?" section to the quick start guide and several new examples illustrating +the new features. The documentation now has two version : the stable version +[https://pythonot.github.io/](https://pythonot.github.io/) +corresponding to the last release and the master version [https://pythonot.github.io/master](https://pythonot.github.io/master) that corresponds to the +current master branch on GitHub. + + +As usual, we want to thank all the POT contributors (now 37 people have +contributed to the toolbox). But for this release we thank in particular Nathan +Cassereau and Kamel Guerda from the AI support team at +[IDRIS](http://www.idris.fr/) for their support to the development of the +backend and OpenMP implementations. + + +#### New features + +- OpenMP support for exact OT solvers (PR #260) +- Backend for running POT in numpy/torch + exact solver (PR #249) +- Backend implementation of most functions in `ot.bregman` (PR #280) +- Backend implementation of most functions in `ot.optim` (PR #282) +- Backend implementation of most functions in `ot.gromov` (PR #294, PR #302) +- Test for arrays of different type and device (CPU/GPU) (PR #304, #303) +- Implementation of Sinkhorn in log space with `method='sinkhorn_log'` (PR #290) +- Implementation of regularization path for L2 Unbalanced OT (PR #274) +- Implementation of Projection Robust Wasserstein (PR #267) +- Implementation of Debiased Sinkhorn Barycenters (PR #291) +- Implementation of Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein + (PR #275) +- Add `pyproject.toml` and build POT without installing cython first (PR #293) +- Lazy implementation in log space for sinkhorn on samples (PR #259) +- Documentation cleanup (PR #298) +- Two up-to-date documentations [for stable + release](https://PythonOT.github.io/) and for [master branch](https://pythonot.github.io/master/). +- Building wheels on ARM for Raspberry PI and smartphones (PR #238) +- Update build wheels to new version and new pythons (PR #236, #253) +- Implementation of sliced Wasserstein distance (Issue #202, PR #203) +- Add minimal build to CI and perform pep8 test separately (PR #210) +- Speedup of tests and return run time (PR #262) +- Add "Why OT" discussion to the documentation (PR #220) +- New introductory example to discrete OT in the documentation (PR #191) +- Add templates for Issues/PR on Github (PR#181) -This is the first release where the Python 2.7 tests have been removed. Most of the toolbox should still work but we do not offer support for Python 2.7 and will close related Issues. +#### Closed issues -A lot of changes have been done to the documentation that is now hosted on [https://PythonOT.github.io/](https://PythonOT.github.io/) instead of readthedocs. It was a hard choice but readthedocs did not allow us to run sphinx-gallery to update our beautiful examples and it was a huge amount of work to maintain. The documentation is now automatically compiled and updated on merge. We also removed the notebooks from the repository for space reason and also because they are all available in the [example gallery](https://pythonot.github.io/auto_examples/index.html). Note that now the output of the documentation build for each commit in the PR is available to check that the doc builds correctly before merging which was not possible with readthedocs. +- Debug Memory leak in GAN example (#254) +- DEbug GPU bug (Issue #284, #287, PR #288) +- set_gradients method for JAX backend (PR #278) +- Quicker GAN example for CircleCI build (PR #258) +- Better formatting in Readme (PR #234) +- Debug CI tests (PR #240, #241, #242) +- Bug in Partial OT solver dummy points (PR #215) +- Bug when Armijo linesearch (Issue #184, #198, #281, PR #189, #199, #286) +- Bug Barycenter Sinkhorn (Issue 134, PR #195) +- Infeasible solution in exact OT (Issues #126,#93, PR #217) +- Doc for SUpport Barycenters (Issue #200, PR #201) +- Fix labels transport in BaseTransport (Issue #207, PR #208) +- Bug in `emd_1d`, non respected bounds (Issue #169, PR #170) +- Removed Python 2.7 support and update codecov file (PR #178) +- Add normalization for WDA and test it (PR #172, #296) +- Cleanup code for new version of `flake8` (PR #176) +- Fixed requirements in `setup.py` (PR #174) +- Removed specific MacOS flags (PR #175) -The CI framework has also been changed with a move from Travis to Github Action which allows to get faster tests on Windows, MacOS and Linux. We also now report our coverage on [Codecov.io](https://codecov.io/gh/PythonOT/POT) and we have a reasonable 92% coverage. We also now generate wheels for a number of OS and Python versions at each merge in the master branch. They are available as outputs of this [action](https://github.com/PythonOT/POT/actions?query=workflow%3A%22Build+dist+and+wheels%22). This will allow simpler multi-platform releases from now on. -In terms of new features we now have [OTDA Classes for unbalanced OT](https://pythonot.github.io/gen_modules/ot.da.html#ot.da.UnbalancedSinkhornTransport), a new Domain adaptation class form [multi domain problems (JCPOT)](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-jcpot-py), and several solvers to solve the [Partial Optimal Transport](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html#sphx-glr-auto-examples-unbalanced-partial-plot-partial-wass-and-gromov-py) problems. +## 0.7.0 +*May 2020* -This release is also the moment to thank all the POT contributors (old and new) for helping making POT such a nice toolbox. A lot of changes (also in the API) are comming for the next versions. +This is the new stable release for POT. We made a lot of changes in the +documentation and added several new features such as Partial OT, Unbalanced and +Multi Sources OT Domain Adaptation and several bug fixes. One important change +is that we have created the GitHub organization +[PythonOT](https://github.com/PythonOT) that now owns the main POT repository +[https://github.com/PythonOT/POT](https://github.com/PythonOT/POT) and the +repository for the new documentation is now hosted at +[https://PythonOT.github.io/](https://PythonOT.github.io/). + +This is the first release where the Python 2.7 tests have been removed. Most of +the toolbox should still work but we do not offer support for Python 2.7 and +will close related Issues. + +A lot of changes have been done to the documentation that is now hosted on +[https://PythonOT.github.io/](https://PythonOT.github.io/) instead of +readthedocs. It was a hard choice but readthedocs did not allow us to run +sphinx-gallery to update our beautiful examples and it was a huge amount of work +to maintain. The documentation is now automatically compiled and updated on +merge. We also removed the notebooks from the repository for space reason and +also because they are all available in the [example +gallery](https://pythonot.github.io/auto_examples/index.html). Note that now the +output of the documentation build for each commit in the PR is available to +check that the doc builds correctly before merging which was not possible with +readthedocs. + +The CI framework has also been changed with a move from Travis to Github Action +which allows to get faster tests on Windows, MacOS and Linux. We also now report +our coverage on [Codecov.io](https://codecov.io/gh/PythonOT/POT) and we have a +reasonable 92% coverage. We also now generate wheels for a number of OS and +Python versions at each merge in the master branch. They are available as +outputs of this +[action](https://github.com/PythonOT/POT/actions?query=workflow%3A%22Build+dist+and+wheels%22). +This will allow simpler multi-platform releases from now on. + +In terms of new features we now have [OTDA Classes for unbalanced +OT](https://pythonot.github.io/gen_modules/ot.da.html#ot.da.UnbalancedSinkhornTransport), +a new Domain adaptation class form [multi domain problems +(JCPOT)](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html#sphx-glr-auto-examples-domain-adaptation-plot-otda-jcpot-py), +and several solvers to solve the [Partial Optimal +Transport](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html#sphx-glr-auto-examples-unbalanced-partial-plot-partial-wass-and-gromov-py) +problems. + +This release is also the moment to thank all the POT contributors (old and new) +for helping making POT such a nice toolbox. A lot of changes (also in the API) +are coming for the next versions. #### Features @@ -31,6 +179,8 @@ This release is also the moment to thank all the POT contributors (old and new) #### Closed issues +- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments (PR + #231, #232) - Bug in Unbalanced OT example (Issue #127) - Clean Cython output when calling setup.py clean (Issue #122) - Various Macosx compilation problems (Issue #113, Issue #118, PR#130) @@ -54,18 +204,24 @@ https://python3statement.org/ for more reasons). For next release we will keep the travis tests for Python 2 but will make them non necessary for merge in 2020. The features are never complete in a toolbox designed for solving mathematical -problems and research but with the new contributions we now implement algorithms and solvers -from 24 scientific papers (listed in the README.md file). New features include a -direct implementation of the [empirical Sinkhorn divergence](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.empirical_sinkhorn_divergence) -, a new efficient (Cython implementation) solver for [EMD in 1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.emd_1d) -and corresponding [Wasserstein -1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.wasserstein_1d). We now also -have implementations for [Unbalanced OT](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_1D.ipynb) -and a solver for [Unbalanced OT barycenters](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_barycenter_1D.ipynb). +problems and research but with the new contributions we now implement algorithms +and solvers from 24 scientific papers (listed in the README.md file). New +features include a direct implementation of the [empirical Sinkhorn +divergence](https://pot.readthedocs.io/en/latest/all.html#ot.bregman.empirical_sinkhorn_divergence), +a new efficient (Cython implementation) solver for [EMD in +1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.emd_1d) and +corresponding [Wasserstein +1D](https://pot.readthedocs.io/en/latest/all.html#ot.lp.wasserstein_1d). We now +also have implementations for [Unbalanced +OT](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_1D.ipynb) and +a solver for [Unbalanced OT +barycenters](https://github.com/rflamary/POT/blob/master/notebooks/plot_UOT_barycenter_1D.ipynb). A new variant of Gromov-Wasserstein divergence called [Fused Gromov-Wasserstein](https://pot.readthedocs.io/en/latest/all.html?highlight=fused_#ot.gromov.fused_gromov_wasserstein) - has been also contributed with exemples of use on [structured data](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb) -and computing [barycenters of labeld graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb). +has been also contributed with exemples of use on [structured +data](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb) and +computing [barycenters of labeld +graphs](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb). A lot of work has been done on the documentation with several new diff --git a/docs/source/readme.rst b/docs/source/readme.rst index ee32e2b..a8f1bc0 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -34,6 +34,9 @@ POT provides the following generic OT solvers (links to examples): [21] and unmixing [4]. - Sinkhorn divergence [23] and entropic regularization OT from empirical data. +- Debiased Sinkhorn barycenters `Sinkhorn divergence + barycenter `__ + [37] - `Smooth optimal transport solvers `__ (dual and semi-dual) for KL and squared L2 regularizations [17]. @@ -44,7 +47,8 @@ POT provides the following generic OT solvers (links to examples): distances `__ and `GW barycenters `__ - (exact [13] and regularized [12]) + (exact [13] and regularized [12]), differentiable using gradients + from - `Fused-Gromov-Wasserstein distances solver `__ and `FGW @@ -70,7 +74,8 @@ POT provides the following generic OT solvers (links to examples): (exact [29] and entropic [3] formulations). - `Sliced Wasserstein `__ - [31, 32]. + [31, 32] and Max-sliced Wasserstein [35] that can be used for + gradient flows [36]. - `Several backends `__ for easy use of POT with @@ -278,7 +283,8 @@ The contributors to this library are Rakotomamonjy `__ - `Vayer Titouan `__ (Gromov-Wasserstein -, Fused-Gromov-Wasserstein) -- `Hicham Janati `__ (Unbalanced OT) +- `Hicham Janati `__ (Unbalanced OT, + Debiased barycenters) - `Romain Tavenard `__ (1d Wasserstein) - `Mokhtar Z. Alaya `__ (Screenkhorn) - `Ievgen Redko `__ (Laplacian DA, JCPOT) @@ -501,6 +507,22 @@ gans `__. +In International Conference on Machine Learning (pp. 4104-4113). PMLR. + +[37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn +barycenters `__ +Proceedings of the 37th International Conference on Machine Learning, +PMLR 119:4692-4701, 2020 + +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, +`Online Graph Dictionary +Learning `__, International +Conference on Machine Learning (ICML), 2021. + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT .. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg diff --git a/docs/source/releases.rst b/docs/source/releases.rst index 5a357f3..aa06105 100644 --- a/docs/source/releases.rst +++ b/docs/source/releases.rst @@ -1,6 +1,132 @@ Releases ======== +0.8.0 +----- + +*November 2021* + +This new stable release introduces several important features. + +First we now have an OpenMP compatible exact ot solver in ``ot.emd``. +The OpenMP version is used when the parameter ``numThreads`` is greater +than one and can lead to nice speedups on multi-core machines. + +| Second we have introduced a backend mechanism that allows to use + standard POT function seamlessly on Numpy, Pytorch and Jax arrays. + Other backends are coming but right now POT can be used seamlessly for + training neural networks in Pytorch. Notably we propose the first + differentiable computation of the exact OT loss with ``ot.emd2`` (can + be differentiated w.r.t. both cost matrix and sample weights), but + also for the classical Sinkhorn loss with ``ot.sinkhorn2``, the + Wasserstein distance in 1D with ``ot.wasserstein_1d``, sliced + Wasserstein with ``ot.sliced_wasserstein_distance`` and + Gromov-Wasserstein with ``ot.gromov_wasserstein2``. Examples of how + this new feature can be used are now available in the documentation + where the Pytorch backend is used to estimate a `minimal Wasserstein + estimator `__, + a `Generative Network + (GAN) `__, + for a `sliced Wasserstein gradient + flow `__ + and `optimizing the Gromov-Wassersein + distance `__. + Note that the Jax backend is still in early development and quite slow + at the moment, we strongly recommend for Jax users to use the `OTT + toolbox `__ when possible. +| As a result of this new feature, the old ``ot.gpu`` submodule is now + deprecated since GPU implementations can be done using GPU arrays on + the torch backends. + +Other novel features include implementation for `Sampled Gromov +Wasserstein and Pointwise Gromov +Wasserstein `__, +Sinkhorn in log space with ``method='sinkhorn_log'``, `Projection Robust +Wasserstein `__, +ans `deviased Sinkorn +barycenters `__. + +This release will also simplify the installation process. We have now a +``pyproject.toml`` that defines the build dependency and POT should now +build even when cython is not installed yet. Also we now provide +pe-compiled wheels for linux ``aarch64`` that is used on Raspberry PI +and android phones and for MacOS on ARM processors. + +Finally POT was accepted for publication in the Journal of Machine +Learning Research (JMLR) open source software track and we ask the POT +users to cite `this +paper `__ from now on. The +documentation has been improved in particular by adding a "Why OT?" +section to the quick start guide and several new examples illustrating +the new features. The documentation now has two version : the stable +version https://pythonot.github.io/ corresponding to the last release +and the master version https://pythonot.github.io/master that +corresponds to the current master branch on GitHub. + +As usual, we want to thank all the POT contributors (now 37 people have +contributed to the toolbox). But for this release we thank in particular +Nathan Cassereau and Kamel Guerda from the AI support team at +`IDRIS `__ for their support to the development of +the backend and OpenMP implementations. + +New features +^^^^^^^^^^^^ + +- OpenMP support for exact OT solvers (PR #260) +- Backend for running POT in numpy/torch + exact solver (PR #249) +- Backend implementation of most functions in ``ot.bregman`` (PR #280) +- Backend implementation of most functions in ``ot.optim`` (PR #282) +- Backend implementation of most functions in ``ot.gromov`` (PR #294, + PR #302) +- Test for arrays of different type and device (CPU/GPU) (PR #304, + #303) +- Implementation of Sinkhorn in log space with + ``method='sinkhorn_log'`` (PR #290) +- Implementation of regularization path for L2 Unbalanced OT (PR #274) +- Implementation of Projection Robust Wasserstein (PR #267) +- Implementation of Debiased Sinkhorn Barycenters (PR #291) +- Implementation of Sampled Gromov Wasserstein and Pointwise Gromov + Wasserstein (PR #275) +- Add ``pyproject.toml`` and build POT without installing cython first + (PR #293) +- Lazy implementation in log space for sinkhorn on samples (PR #259) +- Documentation cleanup (PR #298) +- Two up-to-date documentations `for stable + release `__ and for `master + branch `__. +- Building wheels on ARM for Raspberry PI and smartphones (PR #238) +- Update build wheels to new version and new pythons (PR #236, #253) +- Implementation of sliced Wasserstein distance (Issue #202, PR #203) +- Add minimal build to CI and perform pep8 test separately (PR #210) +- Speedup of tests and return run time (PR #262) +- Add "Why OT" discussion to the documentation (PR #220) +- New introductory example to discrete OT in the documentation (PR + #191) +- Add templates for Issues/PR on Github (PR#181) + +Closed issues +^^^^^^^^^^^^^ + +- Debug Memory leak in GAN example (#254) +- DEbug GPU bug (Issue #284, #287, PR #288) +- set\_gradients method for JAX backend (PR #278) +- Quicker GAN example for CircleCI build (PR #258) +- Better formatting in Readme (PR #234) +- Debug CI tests (PR #240, #241, #242) +- Bug in Partial OT solver dummy points (PR #215) +- Bug when Armijo linesearch (Issue #184, #198, #281, PR #189, #199, + #286) +- Bug Barycenter Sinkhorn (Issue 134, PR #195) +- Infeasible solution in exact OT (Issues #126,#93, PR #217) +- Doc for SUpport Barycenters (Issue #200, PR #201) +- Fix labels transport in BaseTransport (Issue #207, PR #208) +- Bug in ``emd_1d``, non respected bounds (Issue #169, PR #170) +- Removed Python 2.7 support and update codecov file (PR #178) +- Add normalization for WDA and test it (PR #172, #296) +- Cleanup code for new version of ``flake8`` (PR #176) +- Fixed requirements in ``setup.py`` (PR #174) +- Removed specific MacOS flags (PR #175) + 0.7.0 ----- @@ -50,7 +176,7 @@ problems. This release is also the moment to thank all the POT contributors (old and new) for helping making POT such a nice toolbox. A lot of changes -(also in the API) are comming for the next versions. +(also in the API) are coming for the next versions. Features ^^^^^^^^ @@ -72,6 +198,8 @@ Features Closed issues ^^^^^^^^^^^^^ +- Add JMLR paper to teh readme ad Mathieu Blondel to the Acknoledgments + (PR #231, #232) - Bug in Unbalanced OT example (Issue #127) - Clean Cython output when calling setup.py clean (Issue #122) - Various Macosx compilation problems (Issue #113, Issue #118, PR#130) @@ -103,8 +231,8 @@ mathematical problems and research but with the new contributions we now implement algorithms and solvers from 24 scientific papers (listed in the README.md file). New features include a direct implementation of the `empirical Sinkhorn -divergence `__ -, a new efficient (Cython implementation) solver for `EMD in +divergence `__, +a new efficient (Cython implementation) solver for `EMD in 1D `__ and corresponding `Wasserstein 1D `__. diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py index 465f612..969707f 100644 --- a/examples/backends/plot_optim_gromov_pytorch.py +++ b/examples/backends/plot_optim_gromov_pytorch.py @@ -94,7 +94,7 @@ pl.axis("off") # %% -# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1 +# Optimizing GW w.r.t. the weights on a template structure # ------------------------------------------------ # The adajacency matrix C1 is block diagonal with 3 blocks. We want to # optimize the weights of a simple template C0=eye(3) and see if we can diff --git a/ot/__init__.py b/ot/__init__.py index 4292b41..b6dc2b4 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -5,7 +5,8 @@ :py:mod:`ot.lp`, :py:mod:`ot.bregman`, :py:mod:`ot.optim` :py:mod:`ot.utils`, :py:mod:`ot.datasets`, :py:mod:`ot.gromov`, :py:mod:`ot.smooth` - :py:mod:`ot.stochastic` + :py:mod:`ot.stochastic`, :py:mod:`ot.partial`, :py:mod:`ot.regpath` + , :py:mod:`ot.unbalanced`. The following sub-modules are not imported due to additional dependencies: @@ -49,7 +50,7 @@ from .gromov import (gromov_wasserstein, gromov_wasserstein2, # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.8.0dev" +__version__ = "0.8.0" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', @@ -57,5 +58,6 @@ __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', - 'max_sliced_wasserstein_distance', + 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', + 'max_sliced_wasserstein_distance', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath'] diff --git a/ot/gpu/__init__.py b/ot/gpu/__init__.py index e939610..12db605 100644 --- a/ot/gpu/__init__.py +++ b/ot/gpu/__init__.py @@ -7,7 +7,13 @@ The GPU backend in handled by `cupy `_. .. warning:: - Note that by default the module is not import in :mod:`ot`. In order to + This module is now deprecated and will be removed in future releases. POT + now privides a backend mechanism that allows for solving prolem on GPU wth + the pytorch backend. + + +.. warning:: + Note that by default the module is not imported in :mod:`ot`. In order to use it you need to explicitely import :mod:`ot.gpu` . By default, the functions in this module accept and return numpy arrays @@ -36,7 +42,7 @@ from . import utils from .utils import dist, to_gpu, to_np -warnings.warn('This module will be deprecated in the next minor release of POT', category=DeprecationWarning) +warnings.warn('This module is deprecated and will be removed in the next minor release of POT', category=DeprecationWarning) __all__ = ["utils", "dist", "sinkhorn", diff --git a/ot/helpers/__init__.py b/ot/helpers/__init__.py new file mode 100644 index 0000000..b948671 --- /dev/null +++ b/ot/helpers/__init__.py @@ -0,0 +1,3 @@ +# Author: Remi Flamary +# +# License: MIT License -- cgit v1.2.3