author    Hicham Janati <hicham.janati100@gmail.com>  2021-11-03 08:41:35 +0100
committer GitHub <noreply@github.com>  2021-11-03 08:41:35 +0100
commit    e1b67c641da3b3e497db6811af2c200022b10302 (patch)
tree      44d42e1ae50d653bb07dd6ef9c1de14f71b21642
parent    61340d526702616ff000d9e1cf71f52dd199a103 (diff)
[WIP] Add debiased barycenter (Sinkhorn + convolutional sinkhorn) (#291)
* add debiased sinkhorn barycenter + make loops pythonic
* add debiased arg in tests
* add 1d and 2d examples of debiased barycenters
* fix doctest
* fix flake8
* pep8 + make func private + add convergence warnings
* remove rel paths + add rng + pylab to pyplot
* fix stopping criterion debiased
* pass alex
* change params with new API
* add logdomain barycenters + separate debiased API
* test new API
* fix jax read-only ?
* raise error for jax
* test catch jax error
* fix pytest catch error
* fix relative path
* fix flake8
* add warn arg everywhere
* fix ref number
* catch warnings in tests
* add contrib to readme + change ref number
* fix convolution example + gallery thumbnails
* increase coverage
* fix flake

Co-authored-by: Hicham Janati <hicham.janati@inria.fr>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
-rw-r--r--  README.md                                                      |    8
-rw-r--r--  examples/barycenters/plot_barycenter_1D.py                     |   63
-rw-r--r--  examples/barycenters/plot_barycenter_lp_vs_entropic.py         |    2
-rw-r--r--  examples/barycenters/plot_convolutional_barycenter.py          |   53
-rw-r--r--  examples/barycenters/plot_debiased_barycenter.py               |  131
-rw-r--r--  examples/domain-adaptation/plot_otda_color_images.py           |  118
-rw-r--r--  examples/domain-adaptation/plot_otda_linear_mapping.py         |   73
-rw-r--r--  examples/domain-adaptation/plot_otda_mapping_colors_images.py  |  118
-rwxr-xr-x  examples/gromov/plot_gromov_barycenter.py                      |   90
-rw-r--r--  ot/bregman.py                                                  | 1491
-rw-r--r--  test/test_bregman.py                                           |  365
11 files changed, 1837 insertions, 675 deletions
diff --git a/README.md b/README.md
index cfb9744..ff32c53 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,8 @@ POT provides the following generic OT solvers (links to examples):
* [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7].
* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html).
* Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4].
-* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
+* Sinkhorn divergence [23] and entropic regularization OT from empirical data.
+* Debiased Sinkhorn barycenters: [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
@@ -188,7 +189,7 @@ The contributors to this library are
* [Kilian Fatras](https://kilianfatras.github.io/) (Stochastic solvers)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
* [Vayer Titouan](https://tvayer.github.io/) (Gromov-Wasserstein -, Fused-Gromov-Wasserstein)
-* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
+* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT, Debiased barycenters)
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
@@ -293,3 +294,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t
(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.
+
+[37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters. Proceedings of the 37th
+International Conference on Machine Learning, PMLR 119:4692-4701, 2020
\ No newline at end of file
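For orientation, a minimal sketch of the debiased barycenter API this commit introduces (function names and argument order taken from the ot/bregman.py and example diffs below; the `reg` and `weights` values are illustrative, not prescribed):

```python
import numpy as np
import ot
from ot.bregman import barycenter, barycenter_debiased

n = 100                                     # number of histogram bins
A = np.vstack((ot.datasets.make_1D_gauss(n, m=20, s=5),
               ot.datasets.make_1D_gauss(n, m=60, s=8))).T   # histograms as columns
M = ot.utils.dist0(n)                       # squared-distance cost on the bin grid
M /= M.max()
weights = np.array([0.5, 0.5])

bar = barycenter(A, M, 1e-2, weights)                    # entropic (Sinkhorn) barycenter
bar_debiased = barycenter_debiased(A, M, 1e-2, weights)  # debiased variant of [37]
```

The debiased variant removes the entropic blur of the plain Sinkhorn barycenter while taking the same inputs, which is why the two calls are interchangeable in the examples below.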
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 63dc460..2373e99 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -18,10 +18,10 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
-# sphinx_gallery_thumbnail_number = 4
+# sphinx_gallery_thumbnail_number = 1
import numpy as np
-import matplotlib.pylab as pl
+import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
@@ -51,18 +51,6 @@ M = ot.utils.dist0(n)
M /= M.max()
##############################################################################
-# Plot data
-# ---------
-
-#%% plot the distributions
-
-pl.figure(1, figsize=(6.4, 3))
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
-pl.tight_layout()
-
-##############################################################################
# Barycenter computation
# ----------------------
@@ -78,24 +66,20 @@ bary_l2 = A.dot(weights)
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
-pl.figure(2)
-pl.clf()
-pl.subplot(2, 1, 1)
-for i in range(n_distributions):
- pl.plot(x, A[:, i])
-pl.title('Distributions')
+f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
+ax1.plot(x, A, color="black")
+ax1.set_title('Distributions')
-pl.subplot(2, 1, 2)
-pl.plot(x, bary_l2, 'r', label='l2')
-pl.plot(x, bary_wass, 'g', label='Wasserstein')
-pl.legend()
-pl.title('Barycenters')
-pl.tight_layout()
+ax2.plot(x, bary_l2, 'r', label='l2')
+ax2.plot(x, bary_wass, 'g', label='Wasserstein')
+ax2.set_title('Barycenters')
+
+plt.legend()
+plt.show()
##############################################################################
# Barycentric interpolation
# -------------------------
-
#%% barycenter interpolation
n_alpha = 11
@@ -106,24 +90,23 @@ B_l2 = np.zeros((n, n_alpha))
B_wass = np.copy(B_l2)
-for i in range(0, n_alpha):
+for i in range(n_alpha):
alpha = alpha_list[i]
weights = np.array([1 - alpha, alpha])
B_l2[:, i] = A.dot(weights)
B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
+plt.figure(2)
-pl.figure(3)
-
-cmap = pl.cm.get_cmap('viridis')
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -134,18 +117,18 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with l2')
-pl.tight_layout()
+plt.title('Barycenter interpolation with l2')
+plt.tight_layout()
-pl.figure(4)
-cmap = pl.cm.get_cmap('viridis')
+plt.figure(3)
+cmap = plt.cm.get_cmap('viridis')
verts = []
zs = alpha_list
for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = plt.gcf().gca(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -156,7 +139,7 @@ ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
-pl.title('Barycenter interpolation with Wasserstein')
-pl.tight_layout()
+plt.title('Barycenter interpolation with Wasserstein')
+plt.tight_layout()
-pl.show()
+plt.show()
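The interpolation block above boils down to sweeping the barycenter weights between the two input histograms; a condensed sketch under the same assumptions as this example (two 1D Gaussians, normalized squared-distance cost):

```python
import numpy as np
import ot

n, reg = 100, 1e-3
A = np.vstack((ot.datasets.make_1D_gauss(n, m=20, s=5),
               ot.datasets.make_1D_gauss(n, m=60, s=8))).T
M = ot.utils.dist0(n)
M /= M.max()

for alpha in np.linspace(0, 1, 11):
    weights = np.array([1 - alpha, alpha])
    bary_l2 = A.dot(weights)                               # Euclidean mean: bins blend in place
    bary_wass = ot.bregman.barycenter(A, M, reg, weights)  # mass moves along the bin axis
```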
diff --git a/examples/barycenters/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
index 57a6bac..6502f16 100644
--- a/examples/barycenters/plot_barycenter_lp_vs_entropic.py
+++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
-1D Wasserstein barycenter comparison between exact LP and entropic regularization
+1D Wasserstein barycenter: exact LP vs entropic regularization
=================================================================================
This example illustrates the computation of regularized Wasserstein Barycenter
diff --git a/examples/barycenters/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index cbcd4a1..3721f31 100644
--- a/examples/barycenters/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -6,17 +6,18 @@
Convolutional Wasserstein Barycenter example
============================================
-This example is designed to illustrate how the Convolutional Wasserstein Barycenter
-function of POT works.
+This example is designed to illustrate how the Convolutional Wasserstein
+Barycenter function of POT works.
"""
# Author: Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
-
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
import ot
##############################################################################
@@ -25,22 +26,19 @@ import ot
#
# The four distributions are constructed from 4 simple images
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
-f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
-f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
-f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
-f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
+f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
+f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
-A = []
f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
f3 = f3 / np.sum(f3)
f4 = f4 / np.sum(f4)
-A.append(f1)
-A.append(f2)
-A.append(f3)
-A.append(f4)
-A = np.array(A)
+A = np.array([f1, f2, f3, f4])
nb_images = 5
@@ -57,14 +55,13 @@ v4 = np.array((0, 0, 0, 1))
# ----------------------------------------
#
-pl.figure(figsize=(10, 10))
-pl.title('Convolutional Wasserstein Barycenters in POT')
+fig, axes = plt.subplots(nb_images, nb_images, figsize=(7, 7))
+plt.suptitle('Convolutional Wasserstein Barycenters in POT')
cm = 'Blues'
# regularization parameter
reg = 0.004
for i in range(nb_images):
for j in range(nb_images):
- pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
tx = float(i) / (nb_images - 1)
ty = float(j) / (nb_images - 1)
@@ -74,19 +71,19 @@ for i in range(nb_images):
weights = (1 - ty) * tmp1 + ty * tmp2
if i == 0 and j == 0:
- pl.imshow(f1, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f1, cmap=cm)
elif i == 0 and j == (nb_images - 1):
- pl.imshow(f3, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f3, cmap=cm)
elif i == (nb_images - 1) and j == 0:
- pl.imshow(f2, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f2, cmap=cm)
elif i == (nb_images - 1) and j == (nb_images - 1):
- pl.imshow(f4, cmap=cm)
- pl.axis('off')
+ axes[i, j].imshow(f4, cmap=cm)
else:
# call to barycenter computation
- pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
- pl.axis('off')
-pl.show()
+ axes[i, j].imshow(
+ ot.bregman.convolutional_barycenter2d(A, reg, weights),
+ cmap=cm
+ )
+ axes[i, j].axis('off')
+plt.tight_layout()
+plt.show()
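A self-contained sketch of the `convolutional_barycenter2d` call used above, with synthetic arrays standing in for the image files (the one hard requirement, as in the example, is that each 2D histogram is normalized to sum to 1):

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
A = rng.rand(2, 32, 32)                  # two synthetic 32x32 "images"
A /= A.sum(axis=(1, 2))[:, None, None]   # normalize each image to unit total mass

bar = ot.bregman.convolutional_barycenter2d(A, reg=0.004,
                                            weights=np.array([0.5, 0.5]))
```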
diff --git a/examples/barycenters/plot_debiased_barycenter.py b/examples/barycenters/plot_debiased_barycenter.py
new file mode 100644
index 0000000..2a603dd
--- /dev/null
+++ b/examples/barycenters/plot_debiased_barycenter.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Debiased Sinkhorn barycenter demo
+=================================
+
+This example illustrates the computation of the debiased Sinkhorn barycenter
+as proposed in [37]_.
+
+
+.. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters. Proceedings
+ of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020
+"""
+
+# Author: Hicham Janati <hicham.janati100@gmail.com>
+#
+# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
+import os
+from pathlib import Path
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+import ot
+from ot.bregman import (barycenter, barycenter_debiased,
+ convolutional_barycenter2d,
+ convolutional_barycenter2d_debiased)
+
+##############################################################################
+# Debiased barycenter of 1D Gaussians
+# ------------------------------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+epsilons = [5e-3, 1e-2, 5e-2]
+
+
+bars = [barycenter(A, M, reg, weights) for reg in epsilons]
+bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]
+labels = ["Sinkhorn barycenter", "Debiased barycenter"]
+colors = ["indianred", "gold"]
+
+f, axes = plt.subplots(1, len(epsilons), tight_layout=True, sharey=True,
+ figsize=(12, 4), num=1)
+for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):
+ ax.plot(A[:, 0], color="k", ls="--", label="Input data", alpha=0.3)
+ ax.plot(A[:, 1], color="k", ls="--", alpha=0.3)
+ for data, label, color in zip([bar, bar_debiased], labels, colors):
+ ax.plot(data, color=color, label=label, lw=2)
+ ax.set_title(r"$\varepsilon = %.3f$" % eps)
+plt.legend()
+plt.show()
+
+
+##############################################################################
+# Debiased barycenter of 2D images
+# ---------------------------------
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+f1 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
+f2 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
+
+A = np.asarray([f1, f2]) + 1e-2
+A /= A.sum(axis=(1, 2))[:, None, None]
+
+##############################################################################
+# Display the input images
+
+fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)
+for ax, img in zip(axes, A):
+ ax.imshow(img, cmap="Greys")
+ ax.axis("off")
+fig.tight_layout()
+plt.show()
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+bars_sinkhorn, bars_debiased = [], []
+epsilons = [5e-3, 7e-3, 1e-2]
+for eps in epsilons:
+ bar = convolutional_barycenter2d(A, eps)
+ bar_debiased, log = convolutional_barycenter2d_debiased(A, eps, log=True)
+ bars_sinkhorn.append(bar)
+ bars_debiased.append(bar_debiased)
+
+titles = ["Sinkhorn", "Debiased"]
+all_bars = [bars_sinkhorn, bars_debiased]
+fig, axes = plt.subplots(2, 3, figsize=(8, 6), num=3)
+for jj, (method, ax_row, bars) in enumerate(zip(titles, axes, all_bars)):
+ for ii, (ax, img, eps) in enumerate(zip(ax_row, bars, epsilons)):
+ ax.imshow(img, cmap="Greys")
+ if jj == 0:
+ ax.set_title(r"$\varepsilon = %.3f$" % eps, fontsize=13)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.spines['top'].set_visible(False)
+ ax.spines['right'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ if ii == 0:
+ ax.set_ylabel(method, fontsize=15)
+fig.tight_layout()
+plt.show()
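This commit also threads a `warn` flag through the ot.bregman solvers (see the ot/bregman.py diff below): instead of printing, they now emit a warning when the iteration budget runs out. A sketch of the intended behaviour, forcing non-convergence with a deliberately tiny `numItermax`:

```python
import warnings
import numpy as np
import ot

a = b = np.ones(50) / 50     # uniform marginals
M = ot.utils.dist0(50)
M /= M.max()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ot.bregman.sinkhorn(a, b, M, reg=1e-3, numItermax=5)   # far too few iterations
    assert any("did not converge" in str(w.message) for w in caught)

ot.bregman.sinkhorn(a, b, M, reg=1e-3, numItermax=5, warn=False)  # silenced
```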
diff --git a/examples/domain-adaptation/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index 6218b13..06dc8ab 100644
--- a/examples/domain-adaptation/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -19,12 +19,15 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
+
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
def im2mat(img):
@@ -46,16 +49,19 @@ def minmax(img):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
nb = 500
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -65,39 +71,39 @@ Xt = X2[idx2, :]
# Plot original image
# -------------------
-pl.figure(1, figsize=(6.4, 3))
+plt.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
##############################################################################
# Scatter plot of colors
# ----------------------
-pl.figure(2, figsize=(6.4, 3))
+plt.figure(2, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
@@ -130,37 +136,37 @@ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
# Plot new images
# ---------------
-pl.figure(3, figsize=(8, 4))
+plt.figure(3, figsize=(8, 4))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(2, 3, 2)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Image 1 Adapt')
+plt.subplot(2, 3, 2)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Image 1 Adapt')
-pl.subplot(2, 3, 3)
-pl.imshow(I1te)
-pl.axis('off')
-pl.title('Image 1 Adapt (reg)')
+plt.subplot(2, 3, 3)
+plt.imshow(I1te)
+plt.axis('off')
+plt.title('Image 1 Adapt (reg)')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
-pl.subplot(2, 3, 5)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Image 2 Adapt')
+plt.subplot(2, 3, 5)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Image 2 Adapt')
-pl.subplot(2, 3, 6)
-pl.imshow(I2te)
-pl.axis('off')
-pl.title('Image 2 Adapt (reg)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(I2te)
+plt.axis('off')
+plt.title('Image 2 Adapt (reg)')
+plt.tight_layout()
-pl.show()
+plt.show()
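The changes above are plotting and RNG refactors only; for context, a hedged sketch of the color mapping this example visualizes, using the ot.da classes the unchanged parts of the script rely on (the `reg_e` value is illustrative):

```python
import numpy as np
import ot

rng = np.random.RandomState(42)
Xs = rng.rand(500, 3)   # stand-in for sampled source-image pixels (RGB in [0, 1])
Xt = rng.rand(500, 3)   # stand-in for sampled target-image pixels

ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)
transp_Xs = ot_emd.transform(Xs=Xs)                 # source colors pushed to target palette

ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
transp_Xt = ot_sinkhorn.inverse_transform(Xt=Xt)    # target colors mapped back to source
```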
diff --git a/examples/domain-adaptation/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index be47510..a44096a 100644
--- a/examples/domain-adaptation/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -13,9 +13,11 @@ Linear OT mapping estimation
# License: MIT License
# sphinx_gallery_thumbnail_number = 2
+import os
+from pathlib import Path
import numpy as np
-import pylab as pl
+from matplotlib import pyplot as plt
import ot
##############################################################################
@@ -26,17 +28,19 @@ n = 1000
d = 2
sigma = .1
+rng = np.random.RandomState(42)
+
# source samples
-angles = np.random.rand(n, 1) * 2 * np.pi
+angles = rng.rand(n, 1) * 2 * np.pi
xs = np.concatenate((np.sin(angles), np.cos(angles)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xs[:n // 2, 1] += 2
# target samples
-anglet = np.random.rand(n, 1) * 2 * np.pi
+anglet = rng.rand(n, 1) * 2 * np.pi
xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
- axis=1) + sigma * np.random.randn(n, 2)
+ axis=1) + sigma * rng.randn(n, 2)
xt[:n // 2, 1] += 2
@@ -48,9 +52,9 @@ xt = xt.dot(A) + b
# Plot data
# ---------
-pl.figure(1, (5, 5))
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
+plt.figure(1, (5, 5))
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
##############################################################################
@@ -66,13 +70,13 @@ xst = xs.dot(Ae) + be
# Plot transported samples
# ------------------------
-pl.figure(1, (5, 5))
-pl.clf()
-pl.plot(xs[:, 0], xs[:, 1], '+')
-pl.plot(xt[:, 0], xt[:, 1], 'o')
-pl.plot(xst[:, 0], xst[:, 1], '+')
+plt.figure(1, (5, 5))
+plt.clf()
+plt.plot(xs[:, 0], xs[:, 1], '+')
+plt.plot(xt[:, 0], xt[:, 1], 'o')
+plt.plot(xst[:, 0], xst[:, 1], '+')
-pl.show()
+plt.show()
##############################################################################
# Load image data
@@ -94,8 +98,11 @@ def minmax(img):
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
@@ -123,24 +130,24 @@ I2t = minmax(mat2im(xts, I2.shape))
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 7))
+plt.figure(2, figsize=(10, 7))
-pl.subplot(2, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 2, 3)
-pl.imshow(I1t)
-pl.axis('off')
-pl.title('Mapping Im. 1')
+plt.subplot(2, 2, 3)
+plt.imshow(I1t)
+plt.axis('off')
+plt.title('Mapping Im. 1')
-pl.subplot(2, 2, 4)
-pl.imshow(I2t)
-pl.axis('off')
-pl.title('Inverse mapping Im. 2')
+plt.subplot(2, 2, 4)
+plt.imshow(I2t)
+plt.axis('off')
+plt.title('Inverse mapping Im. 2')
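Likewise a plotting/RNG refactor; the estimator producing `Ae, be` in the unchanged part of this script is assumed to be ot.da.OT_mapping_linear, the closed-form Gaussian OT map. A sketch under that assumption:

```python
import numpy as np
import ot

rng = np.random.RandomState(42)
xs = rng.randn(1000, 2)                   # source samples
A = np.array([[1.5, 0.7], [0.7, 1.5]])    # ground-truth linear map
b = np.array([4.0, 2.0])
xt = xs.dot(A) + b                        # target samples

Ae, be = ot.da.OT_mapping_linear(xs, xt)  # estimated from means and covariances
xst = xs.dot(Ae) + be                     # transported source samples
```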
diff --git a/examples/domain-adaptation/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index 72010a6..dbece70 100644
--- a/examples/domain-adaptation/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -21,12 +21,14 @@ discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
# License: MIT License
# sphinx_gallery_thumbnail_number = 3
+import os
+from pathlib import Path
import numpy as np
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
import ot
-r = np.random.RandomState(42)
+rng = np.random.RandomState(42)
def im2mat(img):
@@ -48,17 +50,19 @@ def minmax(img):
# -------------
# Loading images
-I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+I1 = plt.imread(os.path.join(data_path, 'ocean_day.jpg')).astype(np.float64) / 256
+I2 = plt.imread(os.path.join(data_path, 'ocean_sunset.jpg')).astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
# training samples
nb = 500
-idx1 = r.randint(X1.shape[0], size=(nb,))
-idx2 = r.randint(X2.shape[0], size=(nb,))
+idx1 = rng.randint(X1.shape[0], size=(nb,))
+idx2 = rng.randint(X2.shape[0], size=(nb,))
Xs = X1[idx1, :]
Xt = X2[idx2, :]
@@ -99,76 +103,76 @@ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
# Plot original images
# --------------------
-pl.figure(1, figsize=(6.4, 3))
-pl.subplot(1, 2, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Image 1')
+plt.figure(1, figsize=(6.4, 3))
+plt.subplot(1, 2, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot pixel values distribution
# ------------------------------
-pl.figure(2, figsize=(6.4, 5))
+plt.figure(2, figsize=(6.4, 5))
-pl.subplot(1, 2, 1)
-pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
+plt.subplot(1, 2, 1)
+plt.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 1')
-pl.subplot(1, 2, 2)
-pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
-pl.axis([0, 1, 0, 1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-pl.tight_layout()
+plt.subplot(1, 2, 2)
+plt.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+plt.axis([0, 1, 0, 1])
+plt.xlabel('Red')
+plt.ylabel('Blue')
+plt.title('Image 2')
+plt.tight_layout()
##############################################################################
# Plot transformed images
# -----------------------
-pl.figure(2, figsize=(10, 5))
+plt.figure(2, figsize=(10, 5))
-pl.subplot(2, 3, 1)
-pl.imshow(I1)
-pl.axis('off')
-pl.title('Im. 1')
+plt.subplot(2, 3, 1)
+plt.imshow(I1)
+plt.axis('off')
+plt.title('Im. 1')
-pl.subplot(2, 3, 4)
-pl.imshow(I2)
-pl.axis('off')
-pl.title('Im. 2')
+plt.subplot(2, 3, 4)
+plt.imshow(I2)
+plt.axis('off')
+plt.title('Im. 2')
-pl.subplot(2, 3, 2)
-pl.imshow(Image_emd)
-pl.axis('off')
-pl.title('EmdTransport')
+plt.subplot(2, 3, 2)
+plt.imshow(Image_emd)
+plt.axis('off')
+plt.title('EmdTransport')
-pl.subplot(2, 3, 5)
-pl.imshow(Image_sinkhorn)
-pl.axis('off')
-pl.title('SinkhornTransport')
+plt.subplot(2, 3, 5)
+plt.imshow(Image_sinkhorn)
+plt.axis('off')
+plt.title('SinkhornTransport')
-pl.subplot(2, 3, 3)
-pl.imshow(Image_mapping_linear)
-pl.axis('off')
-pl.title('MappingTransport (linear)')
+plt.subplot(2, 3, 3)
+plt.imshow(Image_mapping_linear)
+plt.axis('off')
+plt.title('MappingTransport (linear)')
-pl.subplot(2, 3, 6)
-pl.imshow(Image_mapping_gaussian)
-pl.axis('off')
-pl.title('MappingTransport (gaussian)')
-pl.tight_layout()
+plt.subplot(2, 3, 6)
+plt.imshow(Image_mapping_gaussian)
+plt.axis('off')
+plt.title('MappingTransport (gaussian)')
+plt.tight_layout()
-pl.show()
+plt.show()
diff --git a/examples/gromov/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index e2d88ba..7fe081f 100755
--- a/examples/gromov/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -13,11 +13,13 @@ computation in POT.
#
# License: MIT License
+import os
+from pathlib import Path
import numpy as np
import scipy as sp
-import matplotlib.pylab as pl
+from matplotlib import pyplot as plt
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -89,17 +91,19 @@ def im2mat(img):
return img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
-square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
-cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
-triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
-star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
+this_file = os.path.realpath('__file__')
+data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
+
+square = plt.imread(os.path.join(data_path, 'square.png')).astype(np.float64)[:, :, 2]
+cross = plt.imread(os.path.join(data_path, 'cross.png')).astype(np.float64)[:, :, 2]
+triangle = plt.imread(os.path.join(data_path, 'triangle.png')).astype(np.float64)[:, :, 2]
+star = plt.imread(os.path.join(data_path, 'star.png')).astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
S = 4
xs = [[] for i in range(S)]
-
for nb in range(4):
for i in range(8):
for j in range(8):
@@ -184,64 +188,64 @@ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
-fig = pl.figure(figsize=(10, 10))
+fig = plt.figure(figsize=(10, 10))
-ax1 = pl.subplot2grid((4, 4), (0, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax1 = plt.subplot2grid((4, 4), (0, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
-ax2 = pl.subplot2grid((4, 4), (0, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax2 = plt.subplot2grid((4, 4), (0, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
-ax3 = pl.subplot2grid((4, 4), (0, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax3 = plt.subplot2grid((4, 4), (0, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
-ax4 = pl.subplot2grid((4, 4), (0, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax4 = plt.subplot2grid((4, 4), (0, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
-ax5 = pl.subplot2grid((4, 4), (1, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax5 = plt.subplot2grid((4, 4), (1, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
-ax6 = pl.subplot2grid((4, 4), (1, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax6 = plt.subplot2grid((4, 4), (1, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
-ax7 = pl.subplot2grid((4, 4), (2, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax7 = plt.subplot2grid((4, 4), (2, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
-ax8 = pl.subplot2grid((4, 4), (2, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax8 = plt.subplot2grid((4, 4), (2, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
-ax9 = pl.subplot2grid((4, 4), (3, 0))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax9 = plt.subplot2grid((4, 4), (3, 0))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
-ax10 = pl.subplot2grid((4, 4), (3, 1))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax10 = plt.subplot2grid((4, 4), (3, 1))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
-ax11 = pl.subplot2grid((4, 4), (3, 2))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax11 = plt.subplot2grid((4, 4), (3, 2))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
-ax12 = pl.subplot2grid((4, 4), (3, 3))
-pl.xlim((-1, 1))
-pl.ylim((-1, 1))
+ax12 = plt.subplot2grid((4, 4), (3, 3))
+plt.xlim((-1, 1))
+plt.ylim((-1, 1))
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/ot/bregman.py b/ot/bregman.py
index 0499b8e..786f151 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,7 +7,7 @@ Bregman projections solvers for entropic regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
-# Hicham Janati <hicham.janati@inria.fr>
+# Hicham Janati <hicham.janati100@gmail.com>
# Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
# Alexander Tong <alexander.tong@yale.edu>
# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
@@ -25,7 +25,8 @@ from .backend import get_backend
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -43,8 +44,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
@@ -77,7 +80,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
samples weights in the source domain
b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
@@ -94,6 +98,8 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -117,13 +123,21 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
- .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
+ A., & Peyré, G. (2019, April). Interpolating between optimal transport
+ and MMD using Sinkhorn divergences. In The 22nd International Conference
+ on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
@@ -131,37 +145,44 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
+ ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
+ :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
"""
if method.lower() == 'sinkhorn':
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_log':
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
elif method.lower() == 'greenkhorn':
return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn)
elif method.lower() == 'sinkhorn_stabilized':
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
return sinkhorn_epsilon_scaling(a, b, M, reg,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn,
+ **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+ stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the loss
@@ -179,13 +200,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
**Choosing a Sinkhorn solver**
@@ -212,7 +236,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
samples weights in the source domain
b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
@@ -228,6 +253,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -252,19 +279,27 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
.. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
- algorithms for optimal transport via Sinkhorn iteration, Advances in Neural
- Information Processing Systems (NIPS) 31, 2017
-
- .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
-
+ algorithms for optimal transport via Sinkhorn iteration,
+ Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April).
+ Interpolating between optimal transport and MMD using Sinkhorn
+ divergences. In The 22nd International Conference on Artificial
+ Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
@@ -272,7 +307,8 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
+ ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn :ref:`[9] <references-sinkhorn2>`
+ :ref:`[10] <references-sinkhorn2>`
"""
@@ -317,8 +353,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
raise ValueError("Unknown method '%s'." % method)
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -335,10 +372,13 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
+ The algorithm used for solving the problem is the Sinkhorn-Knopp
+ matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
Parameters
@@ -347,7 +387,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
@@ -360,6 +401,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -384,7 +427,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
See Also
@@ -427,9 +472,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
- cpt = 0
+
err = 1
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
uprev = u
vprev = v
KtransposeU = nx.dot(K.T, u)
@@ -441,11 +486,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Warning: numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
- if cpt % 10 == 0:
+ if ii % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
@@ -457,13 +502,20 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log['niter'] = ii
log['u'] = u
log['v'] = v
@@ -482,8 +534,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def sinkhorn_log(a, b, M, reg, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
@@ -528,6 +580,8 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -552,9 +606,15 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., & Peyré, G. (2019, April). Interpolating between optimal transport and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
+ .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
+ Trouvé, A., & Peyré, G. (2019, April). Interpolating between
+ optimal transport and MMD using Sinkhorn divergences. In The
+ 22nd International Conference on Artificial Intelligence and
+ Statistics (pp. 2681-2690). PMLR.
See Also
@@ -613,7 +673,7 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000,
if log:
log = {'err': []}
- Mr = M / (-reg)
+ Mr = - M / reg
# we assume that no distances are null except those of the diagonal of
# distances
@@ -630,14 +690,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000,
loga = nx.log(a)
logb = nx.log(b)
- cpt = 0
err = 1
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
v = logb - nx.logsumexp(Mr + u[:, None], 0)
u = loga - nx.logsumexp(Mr + v[None, :], 1)
- if cpt % 10 == 0:
+ if ii % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
@@ -648,13 +707,20 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000,
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt = cpt + 1
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log['niter'] = ii
log['log_u'] = u
log['log_v'] = v
log['u'] = nx.exp(u)
@@ -667,11 +733,13 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
- log=False):
+ log=False, warn=True):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
- The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>` which is a stochastic version of the Sinkhorn-Knopp algorithm :ref:`[2] <references-greenkhorn>`
+ The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
+ which is a stochastic version of the Sinkhorn-Knopp
+ algorithm :ref:`[2] <references-greenkhorn>`
The function solves the following optimization problem:
@@ -686,8 +754,10 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
Parameters
@@ -696,7 +766,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
- and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log)
+ and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
+ (return OT loss + dual variables in log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
@@ -707,6 +778,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
Stop threshold on error (>0)
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -731,9 +804,14 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
+ .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
+ approximation algorithms for optimal transport via Sinkhorn
+ iteration, Advances in Neural Information Processing
+ Systems (NIPS) 31, 2017
See Also
@@ -747,7 +825,8 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
nx = get_backend(M, a, b)
if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received. Greenkhorn is not compatible with JAX")
+ raise TypeError("JAX arrays have been received. Greenkhorn is not "
+ "compatible with JAX")
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
@@ -771,7 +850,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
log['u'] = u
log['v'] = v
- for i in range(numItermax):
+ for ii in range(numItermax):
i_1 = nx.argmax(nx.abs(viol))
i_2 = nx.argmax(nx.abs(viol_2))
m_viol_1 = nx.abs(viol[i_1])
@@ -795,14 +874,17 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
viol += (-old_v + new_v) * K[:, i_2] * u
viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2]
v[i_2] = new_v
- # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))
if stopThr_val <= stopThr:
break
else:
- print('Warning: Algorithm did not converge')
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
+ log["n_iter"] = ii
log['u'] = u
log['v'] = v
@@ -814,7 +896,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
@@ -831,13 +913,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
+ weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>` but with the log stabilization
- proposed in :ref:`[10] <references-sinkhorn-stabilized>` an defined in :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1) .
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in
+ :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1).
Parameters
@@ -851,7 +937,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
reg : float
Regularization term >0
tau : float
- threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
warmstart : table of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
@@ -862,6 +949,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -886,11 +975,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
+ Optimal Transport, Advances in Neural Information Processing
+ Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
+ for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
See Also
@@ -920,7 +1015,6 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
dim_a = len(a)
dim_b = len(b)
- cpt = 0
if log:
log = {'err': []}
@@ -935,7 +1029,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
- u, v = nx.ones(dim_a, type_as=M) / dim_a, nx.ones(dim_b, type_as=M) / dim_b
+ u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
+ u /= dim_a
+ v /= dim_b
def get_K(alpha, beta):
"""log space computation"""
@@ -947,21 +1043,17 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
/ reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))
- # print(np.min(K))
-
K = get_K(alpha, beta)
transp = K
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
uprev = u
vprev = v
# sinkhorn update
- v = b / (nx.dot(K.T, u) + 1e-16)
- u = a / (nx.dot(K, v) + 1e-16)
+ v = b / (nx.dot(K.T, u))
+ u = a / (nx.dot(K, v))
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
@@ -977,7 +1069,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)
- if cpt % print_period == 0:
+ if ii % print_period == 0:
# we can speed up the process by checking for the error only every
# 10th iteration
if n_hists:
@@ -993,33 +1085,33 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log['err'].append(err)
if verbose:
- if cpt % (print_period * 20) == 0:
+ if ii % (print_period * 20) == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
if err <= stopThr:
- loop = False
-
- if cpt >= numItermax:
- loop = False
+ break
if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
+ warnings.warn('Numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
-
- cpt = cpt + 1
-
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
logu = alpha / reg + nx.log(u)
logv = beta / reg + nx.log(v)
+ log["n_iter"] = ii
log['logu'] = logu
log['logv'] = logv
log['alpha'] = alpha + reg * nx.log(u)
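For readers following the stabilization: `logu`/`logv` recombine the absorbed potentials with the residual scalings, so the plan can be rebuilt outside the solver. A NumPy sketch mirroring `get_Gamma` above (the function name is illustrative, not part of the patch):

    import numpy as np

    def rebuild_plan(alpha, beta, u, v, M, reg):
        # Gamma_ij = exp((alpha_i + beta_j - M_ij) / reg) * u_i * v_j
        log_plan = (alpha[:, None] + beta[None, :] - M) / reg
        return np.exp(log_plan + np.log(u)[:, None] + np.log(v)[None, :])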
@@ -1048,13 +1140,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numInnerItermax=100, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=10,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
-
The function solves the following optimization problem:
-
.. math::
\gamma = \mathop{\arg \min}_\gamma <\gamma, \mathbf{M}>_F + \mathrm{reg}\cdot\Omega(\gamma)
@@ -1064,16 +1154,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
\gamma &\geq 0
where :
-
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
-
-
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights
+ (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
- scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>` but with the log stabilization
- proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
-
+ scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
+ but with the log stabilization
+ proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
+ proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
Parameters
----------
@@ -1086,7 +1176,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
reg : float
Regularization term >0
tau : float
- threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` for log scaling
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
warmstart : tuple of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
@@ -1101,6 +1192,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -1108,10 +1201,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
-
Examples
--------
-
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
@@ -1123,19 +1214,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
.. _references-sinkhorn-epsilon-scaling:
References
----------
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
-
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
-
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
-
"""
a, b, M = list_to_array(a, b, M)
@@ -1155,7 +1246,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
numItermin = 35
numItermax = max(numItermin, numItermax) # ensure that the last value is exact
- cpt = 0
+ ii = 0
if log:
log = {'err': []}
@@ -1170,12 +1261,10 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
def get_reg(n): # exponential decreasing
return (epsilon0 - reg) * np.exp(-n) + reg
- loop = 1
- cpt = 0
err = 1
- while loop:
+ for ii in range(numItermax):
- regi = get_reg(cpt)
+ regi = get_reg(ii)
G, logi = sinkhorn_stabilized(a, b, M, regi,
numItermax=numInnerItermax, stopThr=1e-9,
@@ -1185,10 +1274,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
alpha = logi['alpha']
beta = logi['beta']
- if cpt >= numItermax:
- loop = False
-
- if cpt % (print_period) == 0: # spsion nearly converged
+ if ii % (print_period) == 0: # epsilon nearly converged
# we can speed up the process by checking for the error only every
# 10th iteration
transp = G
@@ -1197,19 +1283,22 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
log['err'].append(err)
if verbose:
- if cpt % (print_period * 10) == 0:
+ if ii % (print_period * 10) == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- if err <= stopThr and cpt > numItermin:
- loop = False
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
+ if err <= stopThr and ii > numItermin:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
log['alpha'] = alpha
log['beta'] = beta
log['warmstart'] = (log['alpha'], log['beta'])
+ log['niter'] = ii
return G, log
else:
return G
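The schedule driving the outer loop decays exponentially from `epsilon0` to the target `reg`, which is why `numItermin` iterations are forced: the last ones run at (almost exactly) `reg`. A standalone sketch of the schedule (values illustrative):

    import numpy as np

    epsilon0, reg = 1e4, 1e-1

    def get_reg(n):  # same decay as in the function above
        return (epsilon0 - reg) * np.exp(-n) + reg

    print([round(get_reg(n), 4) for n in range(0, 20, 4)])
    # approx [10000.0, 183.25, 3.45, 0.16, 0.10]; plateaus at reg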
@@ -1245,7 +1334,7 @@ def projC(gamma, q):
def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
- stopThr=1e-4, verbose=False, log=False, **kwargs):
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
@@ -1255,11 +1344,16 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`OT_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`.
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter>`
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter>`
Parameters
----------
@@ -1270,7 +1364,7 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
reg : float
Regularization term > 0
method : str (optional)
- method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
+ method used for the solver, either 'sinkhorn', 'sinkhorn_stabilized' or 'sinkhorn_log'
weights : array-like, shape (n_hists,)
Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
@@ -1281,6 +1375,8 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
@@ -1295,7 +1391,9 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1303,18 +1401,24 @@ def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
return barycenter_sinkhorn(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn,
**kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_stabilized(A, M, reg, weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
- log=log, **kwargs)
+ log=log, warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_sinkhorn_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
else:
raise ValueError("Unknown method '%s'." % method)
def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
+ stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
The function solves the following optimization problem:
@@ -1324,11 +1428,15 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter-sinkhorn>`
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
+ scaling algorithm as proposed in :ref:`[3] <references-barycenter-sinkhorn>`.
Parameters
----------
@@ -1348,6 +1456,8 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
@@ -1362,7 +1472,9 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1378,43 +1490,109 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
if log:
log = {'err': []}
- # M = M/np.median(M) # suggested by G. Peyre
K = nx.exp(-M / reg)
- cpt = 0
err = 1
UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
u = (geometricMean(UKv) / UKv.T).T
- while (err > stopThr and cpt < numItermax):
- cpt = cpt + 1
+ for ii in range(numItermax):
+
UKv = u * nx.dot(K, A / nx.dot(K, u))
u = (u.T * geometricBar(weights, UKv)).T / UKv
- if cpt % 10 == 1:
+ if ii % 10 == 1:
err = nx.sum(nx.std(UKv, axis=1))
# log and verbose print
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
return geometricBar(weights, UKv), log
else:
return geometricBar(weights, UKv)
+def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the entropic wasserstein barycenter in log-domain
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar = log_bar + weights[k] * log_KU[:, k]
+
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
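`_barycenter_sinkhorn_log` reproduces the `barycenter_sinkhorn` iterations entirely through logsumexp, which is what makes very small `reg` values usable. A NumPy re-derivation of one inner step, as a sketch assuming `Mr = -M / reg` and scipy's `logsumexp` in place of `nx.logsumexp`:

    import numpy as np
    from scipy.special import logsumexp

    def log_step(logA, G, Mr, weights):
        # per histogram k: f_k <- log a_k - logsumexp_j(Mr + g_k),
        # then (log K u)_k <- logsumexp_i(Mr + f_k)
        log_KU = np.empty_like(logA)
        for k in range(logA.shape[1]):
            f = logA[:, k] - logsumexp(Mr + G[None, :, k], axis=1)
            log_KU[:, k] = logsumexp(Mr + f[:, None], axis=0)
        log_bar = log_KU @ weights  # weighted geometric mean, in log domain
        return log_bar, log_bar[:, None] - log_KU  # new log barycenter, G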
def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
+ stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization.
The function solves the following optimization problem:
@@ -1424,11 +1602,15 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
- - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
+ algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`
Parameters
----------
@@ -1439,7 +1621,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
reg : float
Regularization term > 0
tau : float
- threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` for log scaling
+ threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
+ for log scaling
weights : array-like, shape (n_hists,)
Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
numItermax : int, optional
@@ -1450,6 +1633,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
@@ -1464,7 +1649,9 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
References
----------
- .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+ Iterative Bregman projections for regularized transportation problems.
+ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""
@@ -1486,19 +1673,18 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
K = nx.exp(-M / reg)
- cpt = 0
err = 1.
alpha = nx.zeros((dim,), type_as=M)
beta = nx.zeros((dim,), type_as=M)
q = nx.ones((dim,), type_as=M) / dim
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
qprev = q
Kv = nx.dot(K, v)
- u = A / (Kv + 1e-16)
+ u = A / Kv
Ktu = nx.dot(K.T, u)
q = geometricBar(weights, Ktu)
Q = q[:, None]
- v = Q / (Ktu + 1e-16)
+ v = Q / Ktu
absorbing = False
if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
@@ -1512,40 +1698,244 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration %s' % cpt)
+ warnings.warn('Numerical errors at iteration %s' % ii)
q = qprev
break
- if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ if (ii % 10 == 0 and not absorbing) or ii == 0:
# we can speed up the process by checking for the error only every
# 10th iteration
err = nx.max(nx.abs(u * Kv - A))
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 50 == 0:
+ if ii % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
- cpt += 1
- if err > stopThr:
- warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
- "Try a larger entropy `reg`" +
- "Or a larger absorption threshold `tau`.")
+ else:
+ if warn:
+ warnings.warn("Stabilized Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
if log:
- log['niter'] = cpt
- log['logu'] = nx.log(u + 1e-16)
- log['logv'] = nx.log(v + 1e-16)
+ log['niter'] = ii
+ log['logu'] = nx.log(u + 1e-16)
+ log['logv'] = nx.log(v + 1e-16)
return q, log
else:
return q
-def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
- stopThr=1e-9, stabThr=1e-30, verbose=False,
- log=False):
- r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
- where :math:`\mathbf{A}` is a collection of 2D images.
+def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
+ r"""Compute the debiased Sinkhorn barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
+ (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
+ :math:`\mathbf{A}`
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
+ the cost matrix for OT
+
+ The algorithm used for solving the problem is the debiased Sinkhorn
+ algorithm as proposed in :ref:`[37] <references-sinkhorn-debiased>`
+
+ Parameters
+ ----------
+ A : array-like, shape (dim, n_hists)
+ `n_hists` training distributions :math:`a_i` of size `dim`
+ M : array-like, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver, either 'sinkhorn' or 'sinkhorn_log'
+ weights : array-like, shape (n_hists,)
+ Weights of each histogram :math:`a_i` on the simplex (barycentric coordinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+
+
+ Returns
+ -------
+ a : (dim,) array-like
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+ .. _references-sinkhorn-debiased:
+ References
+ ----------
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn Barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _barycenter_debiased(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _barycenter_debiased_log(A, M, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
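A hypothetical end-to-end call of the new public entry point on toy data (array shapes follow the docstring; nothing here is taken from the patch's tests):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    A = rng.rand(50, 3)                 # three histograms of size 50
    A /= A.sum(axis=0, keepdims=True)   # project columns onto the simplex
    x = np.linspace(0, 1, 50)
    M = (x[:, None] - x[None, :]) ** 2  # squared cost on a 1D grid
    bar = ot.bregman.barycenter_debiased(A, M, reg=1e-2,
                                         method="sinkhorn_log")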
+def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False, warn=True):
+ r"""Compute the debiased sinkhorn barycenter of distributions A.
+ """
+
+ A, M = list_to_array(A, M)
+
+ nx = get_backend(A, M)
+
+ if weights is None:
+ weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ K = nx.exp(-M / reg)
+
+ err = 1
+
+ UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)
+
+ u = (geometricMean(UKv) / UKv.T).T
+ c = nx.ones(A.shape[0], type_as=A)
+ bar = nx.ones(A.shape[0], type_as=A)
+
+ for ii in range(numItermax):
+ bold = bar
+ UKv = nx.dot(K, A / nx.dot(K, u))
+ bar = c * geometricBar(weights, UKv)
+ u = bar[:, None] / UKv
+ c = (c * bar / nx.dot(K, c)) ** 0.5
+
+ if ii % 10 == 9:
+ err = abs(bar - bold).max() / max(bar.max(), 1.)
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return bar, log
+ else:
+ return bar
+
+
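The only difference from `barycenter_sinkhorn` above is the extra vector `c`, the debiasing term of [37], updated by a square-root fixed point. Isolated as a NumPy sketch (names illustrative):

    import numpy as np

    def debias_step(c, bar, K):
        # c <- sqrt(c * bar / (K c)); at the fixed point c * (K c) = bar,
        # which cancels the entropic bias of the plain barycenter
        return np.sqrt(c * bar / K.dot(c))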
+def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True):
+ r"""Compute the debiased sinkhorn barycenter in log domain.
+ """
+
+ A, M = list_to_array(A, M)
+ dim, n_hists = A.shape
+
+ nx = get_backend(A, M)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ if weights is None:
+ weights = nx.ones(n_hists, type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ M = - M / reg
+ logA = nx.log(A + 1e-15)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ c = nx.zeros(dim, type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros(dim, type_as=A)
+ for k in range(n_hists):
+ f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
+ log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
+ log_bar += weights[k] * log_KU[:, k]
+ log_bar += c
+ if ii % 10 == 1:
+ err = nx.exp(G + log_KU).std(axis=1).sum()
+
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if err < stopThr and ii > 20:
+ break
+ if verbose:
+ if ii % 200 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+
+ G = log_bar[:, None] - log_KU
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0))
+
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
+def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False,
+ warn=True, **kwargs):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
The function solves the following optimization problem:
@@ -1554,11 +1944,14 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
where :
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}`
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
+ distance (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions
+ of matrix :math:`\mathbf{A}`
- `reg` is the regularization strength scalar value
- The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm
+ as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`
Parameters
----------
@@ -1568,6 +1961,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
Regularization term >0
weights : array-like, shape (n_hists,)
Weights of each image on the simplex (barycentric coordinates)
+ method : string, optional
+ method used for the solver, either 'sinkhorn' or 'sinkhorn_log'
numItermax : int, optional
Max number of iterations
stopThr : float, optional
@@ -1578,6 +1973,8 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -1591,9 +1988,36 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
References
----------
- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66
+ .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
+ A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
+ Efficient optimal transportation on geometric domains. ACM Transactions
+ on Graphics (TOG), 34(4), 66
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn Barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020
+ """
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-9, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images.
"""
A = list_to_array(A)
@@ -1608,65 +2032,373 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
if log:
log = {'err': []}
- b = nx.zeros(A.shape[1:], type_as=A)
+ bar = nx.ones(A.shape[1:], type_as=A)
+ bar /= bar.sum()
U = nx.ones(A.shape, type_as=A)
- KV = nx.ones(A.shape, type_as=A)
-
- cpt = 0
+ V = nx.ones(A.shape, type_as=A)
err = 1
# build the convolution operator
# this is equivalent to blurring on horizontal then vertical directions
t = nx.linspace(0, 1, A.shape[1])
[Y, X] = nx.meshgrid(t, t)
- xi1 = nx.exp(-(X - Y) ** 2 / reg)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
t = nx.linspace(0, 1, A.shape[2])
[Y, X] = nx.meshgrid(t, t)
- xi2 = nx.exp(-(X - Y) ** 2 / reg)
-
- def K(x):
- return nx.dot(nx.dot(xi1, x), xi2)
-
- while (err > stopThr and cpt < numItermax):
-
- bold = b
- cpt = cpt + 1
-
- b = nx.zeros(A.shape[1:], type_as=A)
- KV_cols = []
- for r in range(A.shape[0]):
- KV_col_r = K(A[r, :, :] / nx.maximum(stabThr, K(U[r, :, :])))
- b += weights[r] * nx.log(nx.maximum(stabThr, U[r, :, :] * KV_col_r))
- KV_cols.append(KV_col_r)
- KV = nx.stack(KV_cols)
- b = nx.exp(b)
-
- U = nx.stack([
- b / nx.maximum(stabThr, KV[r, :, :])
- for r in range(A.shape[0])
- ])
- if cpt % 10 == 1:
- err = nx.sum(nx.abs(bold - b))
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ log['U'] = U
+ return bar, log
+ else:
+ return bar
+
+
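Rather than materializing the full (width x height)^2 kernel, the convolutional solvers factor the Gaussian blur into two small 1D kernels applied with `einsum`. The same trick in plain NumPy, as a sketch (not the patch's code):

    import numpy as np

    def gauss_kernel(n, reg):
        # 1D Gaussian kernel on a regular grid over [0, 1]
        t = np.linspace(0, 1, n)
        return np.exp(-(t[:, None] - t[None, :]) ** 2 / reg)

    def convol_imgs(imgs, K1, K2):
        # blur image rows through K1, then columns through K2
        kx = np.einsum("ij,kjl->kil", K1, imgs)
        return np.einsum("ij,klj->kli", K2, kx)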
+def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-4, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ where A is a collection of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+
+ n_hists, width, height = A.shape
+
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ G = log_bar[None, :, :] - log_KU
+
+ else:
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
+
+
+def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
+ numItermax=10000, stopThr=1e-3,
+ verbose=False, log=False, warn=True,
+ **kwargs):
+ r"""Compute the debiased sinkhorn barycenter of distributions A
+ where A is a collection of 2D images.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
+ (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
+ - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
+ dimensions of matrix :math:`\mathbf{A}`
+ - `reg` is the regularization strength scalar value
+
+ The algorithm used for solving the problem is the debiased Sinkhorn scaling
+ algorithm as proposed in :ref:`[37] <references-sinkhorn-debiased>`
+
+ Parameters
+ ----------
+ A : array-like, shape (n_hists, width, height)
+ `n` distributions (2D images) of size `width` x `height`
+ reg : float
+ Regularization term >0
+ weights : array-like, shape (n_hists,)
+ Weights of each image on the simplex (barycentric coordinates)
+ method : string, optional
+ method used for the solver, either 'sinkhorn' or 'sinkhorn_log'
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshold on error (> 0)
+ stabThr : float, optional
+ Stabilization threshold to avoid numerical precision issue
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
+
+
+ Returns
+ -------
+ a : array-like, shape (width, height)
+ 2D Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ .. _references-sinkhorn-debiased:
+ References
+ ----------
+
+ .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn Barycenters.
+ Proceedings of the 37th International Conference on Machine Learning,
+ PMLR 119:4692-4701, 2020
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _convolutional_barycenter2d_debiased(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_log':
+ return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, warn=warn,
+ **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
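A hypothetical call of the 2D debiased entry point on toy images (`reg` lives on the [0, 1]^2 grid scale set by the kernels; the value below is illustrative):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    imgs = rng.rand(4, 32, 32)                    # four 32x32 "images"
    imgs /= imgs.sum(axis=(1, 2), keepdims=True)  # each sums to 1
    bar = ot.bregman.convolutional_barycenter2d_debiased(
        imgs, reg=4e-3, method="sinkhorn_log")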
+def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-15, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.
+ """
+
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
+
+ nx = get_backend(A)
+
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == n_hists)
+
+ if log:
+ log = {'err': []}
+
+ bar = nx.ones((width, height), type_as=A)
+ bar /= width * height
+ U = nx.ones(A.shape, type_as=A)
+ V = nx.ones(A.shape, type_as=A)
+ c = nx.ones(A.shape[1:], type_as=A)
+ err = 1
+
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ K1 = nx.exp(-(X - Y) ** 2 / reg)
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ K2 = nx.exp(-(X - Y) ** 2 / reg)
+
+ def convol_imgs(imgs):
+ kx = nx.einsum("...ij,kjl->kil", K1, imgs)
+ kxy = nx.einsum("...ij,klj->kli", K2, kx)
+ return kxy
+
+ KU = convol_imgs(U)
+ for ii in range(numItermax):
+ V = bar[None] / KU
+ KV = convol_imgs(V)
+ U = A / KV
+ KU = convol_imgs(U)
+ bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0))
+
+ for _ in range(10):
+ c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5
+
+ if ii % 10 == 9:
+ err = (V * KU).std(axis=0).sum()
# log and verbose print
if log:
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(ii, err))
+ # debiased Sinkhorn does not converge monotonically
+ # guarantee a few iterations are done before stopping
+ if err < stopThr and ii > 20:
+ break
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['U'] = U
- return b, log
+ return bar, log
+ else:
+ return bar
+
+
+def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
+ stopThr=1e-3, stabThr=1e-30, verbose=False,
+ log=False, warn=True):
+ r"""Compute the debiased barycenter of 2D images in log-domain.
+ """
+
+ A = list_to_array(A)
+ n_hists, width, height = A.shape
+ nx = get_backend(A)
+ if nx.__name__ == "jax":
+ raise NotImplementedError("Log-domain functions are not yet implemented"
+ " for Jax. Use numpy or torch arrays instead.")
+ if weights is None:
+ weights = nx.ones((n_hists,), type_as=A) / n_hists
+ else:
+ assert (len(weights) == A.shape[0])
+
+ if log:
+ log = {'err': []}
+
+ err = 1
+ # build the convolution operator
+ # this is equivalent to blurring on horizontal then vertical directions
+ t = nx.linspace(0, 1, width)
+ [Y, X] = nx.meshgrid(t, t)
+ M1 = - (X - Y) ** 2 / reg
+
+ t = nx.linspace(0, 1, height)
+ [Y, X] = nx.meshgrid(t, t)
+ M2 = - (X - Y) ** 2 / reg
+
+ def convol_img(log_img):
+ log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
+ log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
+ return log_img
+
+ logA = nx.log(A + stabThr)
+ log_bar, c = nx.zeros((2, width, height), type_as=A)
+ log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
+ err = 1
+ for ii in range(numItermax):
+ log_bar = nx.zeros((width, height), type_as=A)
+ for k in range(n_hists):
+ f = logA[k] - convol_img(G[k])
+ log_KU[k] = convol_img(f)
+ log_bar = log_bar + weights[k] * log_KU[k]
+ log_bar += c
+ for _ in range(10):
+ c = 0.5 * (c + log_bar - convol_img(c))
+
+ if ii % 10 == 9:
+ err = nx.exp(G + log_KU).std(axis=0).sum()
+ # log and verbose print
+ if log:
+ log['err'].append(err)
+
+ if verbose:
+ if ii % 200 == 0:
+ print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr and ii > 20:
+ break
+ G = log_bar[None, :, :] - log_KU
+
else:
- return b
+ if warn:
+ warnings.warn("Convolutional Sinkhorn did not converge. "
+ "Try a larger number of iterations `numItermax` "
+ "or a larger entropy `reg`.")
+ if log:
+ log['niter'] = ii
+ return nx.exp(log_bar), log
+ else:
+ return nx.exp(log_bar)
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
- stopThr=1e-3, verbose=False, log=False):
+ stopThr=1e-3, verbose=False, log=False, warn=True):
r"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
@@ -1679,16 +2411,21 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
where :
- - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
- - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, its expected shape is `(dim_a, n_atoms)`
+ - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ with M loss matrix (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
+ its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
- - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix (`dim_a`, `dim_a`) for OT data fitting
- - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization term and the cost matrix (`dim_prior`, `n_atoms`) regularization
+ - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
+ cost matrix (`dim_a`, `dim_a`) for OT data fitting
+ - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization
+ term and the cost matrix (`dim_prior`, `n_atoms`) regularization
- :math:`\\alpha` weight data fitting and regularization
- The optimization problem is solved following the algorithm described in :ref:`[4] <references-unmix>`
+ The optimization problem is solved following the algorithm described
+ in :ref:`[4] <references-unmix>`
Parameters
@@ -1717,7 +2454,8 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Print information along iterations
log : bool, optional
record log if True
-
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -1731,8 +2469,10 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
References
----------
- .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
-
+ .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+ Supervised planetary unmixing with optimal transport, Workshop
+ on Hyperspectral Image and Signal Processing:
+ Evolution in Remote Sensing (WHISPERS), 2016.
"""
a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
@@ -1747,12 +2487,11 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
old = h0
err = 1
- cpt = 0
# log = {'niter':0, 'all_err':[]}
if log:
log = {'err': []}
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
K = projC(K, a)
K0 = projC(K0, h0)
new = nx.sum(K0, axis=1)
@@ -1770,22 +2509,27 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- cpt = cpt + 1
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ if err < stopThr:
+ break
+ else:
+ if warn:
+ warnings.warn("Unmixing algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
- log['niter'] = cpt
+ log['niter'] = ii
return nx.sum(K0, axis=1), log
else:
return nx.sum(K0, axis=1)
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in :ref:`[27] <references-jcpot-barycenter>`
+ stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
+ r'''Joint OT and proportion estimation for multi-source target shift as
+ proposed in :ref:`[27] <references-jcpot-barycenter>`
The function solves the following optimization problem:
@@ -1799,16 +2543,23 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
where :
- :math:`\lambda_k` is the weight of `k`-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source domain and `C` is the number of classes
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
+ (see :py:func:`ot.bregman.sinkhorn`)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain
+ defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape
+ is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source
+ domain and `C` is the number of classes
- :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
- :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in
+ [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(n_k, C)`
- The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+ The problem consists in solving a Wasserstein barycenter problem to estimate
+ the proportions :math:`\mathbf{h}` in the target domain.
The algorithm used for solving the problem is the Iterative Bregman projections algorithm
- with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
+ with two sets of marginal constraints related to the unknown vector
+ :math:`\mathbf{h}` and uniform target distribution.
Parameters
----------
@@ -1826,10 +2577,12 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Max number of iterations
stopThr : float, optional
Stop threshold on relative change in the barycenter (>0)
- log : bool, optional
- record log if True
verbose : bool, optional (default=False)
Controls the verbosity of the optimization algorithm
+ log : bool, optional
+ record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -1844,9 +2597,8 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
----------
.. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
'''
Xs = list_to_array(*Xs)
@@ -1901,11 +2653,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
# uniform target distribution
a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])
- cpt = 0 # iterations count
err = 1
old_bary = nx.ones((nbclasses,), type_as=Xs[0])
- while (err > stopThr and cpt < numItermax):
+ for ii in range(numItermax):
bary = nx.zeros((nbclasses,), type_as=Xs[0])
@@ -1923,21 +2674,27 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
K[d] = projR(K[d], new)
err = nx.norm(bary - old_bary)
- cpt = cpt + 1
+
old_bary = bary
if log:
log['err'].append(err)
+ if err < stopThr:
+ break
if verbose:
- if cpt % 200 == 0:
+ if ii % 200 == 0:
print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
+ print('{:5d}|{:8e}|'.format(ii, err))
+ else:
+ if warn:
+ warnings.warn("Algorithm did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
bary = bary / nx.sum(bary)
if log:
- log['niter'] = cpt
+ log['niter'] = ii
log['M'] = M
log['D1'] = D1
log['D2'] = D2
@@ -1949,7 +2706,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
- log=False, **kwargs):
+ log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1967,7 +2724,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
where :
- :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
@@ -1988,7 +2746,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
stopThr : float, optional
Stop threshold on error (>0)
isLazy: boolean, optional
- If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function.
+ If True, then only calculate the cost matrix by block and return
+ the dual potentials only (to save memory). If False, calculate full
+ cost matrix and return outputs of sinkhorn function.
batchSize: int or tuple of 2 int, optional
Size of the batches used to compute the sinkhorn update without memory overhead.
When a tuple is provided it sets the size of the left/right batches.
@@ -1996,6 +2756,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
@@ -2021,11 +2783,14 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
'''
X_s, X_t = list_to_array(X_s, X_t)
@@ -2100,7 +2865,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
if err <= stopThr:
break
-
+ else:
+ if warn:
+ warnings.warn("Sinkhorn did not converge. You might want to "
+ "increase the number of iterations `numItermax` "
+ "or the regularization parameter `reg`.")
if log:
dict_log["u"] = f
dict_log["v"] = g
@@ -2111,15 +2880,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
else:
M = dist(X_s, X_t, metric=metric)
if log:
- pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+ pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=True, **kwargs)
return pi, log
else:
- pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+ pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
+ verbose=verbose, log=False, **kwargs)
return pi
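Note the return-type switch documented above: dense mode returns the transport plan, lazy mode only the dual potentials. A hypothetical side-by-side on toy samples:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s, X_t = rng.randn(200, 2), rng.randn(150, 2) + 1.
    # dense: the full 200 x 150 transport plan
    G = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1.)
    # lazy: dual potentials only; the cost matrix is built block by block
    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1., isLazy=True)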
-def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
+def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9, isLazy=False,
+ batchSize=100, verbose=False, log=False, warn=True, **kwargs):
r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -2138,7 +2910,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
where :
- :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
@@ -2159,7 +2932,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
stopThr : float, optional
Stop threshold on error (>0)
isLazy: boolean, optional
- If True, then only calculate the cost matrix by block and return the dual potentials only (to save memory). If False, calculate full cost matrix and return outputs of sinkhorn function.
+ If True, then only calculate the cost matrix by block and return
+ the dual potentials only (to save memory). If False, calculate
+ full cost matrix and return outputs of sinkhorn function.
batchSize: int or tuple of 2 int, optional
Size of the batches used to compute the sinkhorn update without memory overhead.
When a tuple is provided it sets the size of the left/right batches.
@@ -2167,6 +2942,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+ if True, raises a warning if the algorithm doesn't converge.
Returns
@@ -2192,11 +2969,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
References
----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
+ of Optimal Transport, Advances in Neural Information
+ Processing Systems (NIPS) 26, 2013
- .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
+ .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
+ Algorithms for Entropy Regularized Transport Problems.
+ arXiv preprint arXiv:1610.06519.
- .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+ .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ Scaling algorithms for unbalanced transport problems.
+ arXiv preprint arXiv:1607.05816.
'''
X_s, X_t = list_to_array(X_s, X_t)
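For context, a minimal usage sketch of the updated `empirical_sinkhorn2` signature, including the lazy block-wise mode and the new `warn` flag (this sketch is not part of the patch; sample sizes, `reg`, and batch sizes are arbitrary):

# Illustrative sketch only, not part of the patch.
import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(50, 2)   # source samples
X_t = rng.randn(60, 2)   # target samples

# Dense mode: builds the full cost matrix and returns the Sinkhorn loss.
loss = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0)

# Lazy mode: the cost matrix is computed block by block (batchSize rows or
# columns at a time), trading speed for memory on large sample sets.
loss_lazy = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0,
                                           isLazy=True, batchSize=16)

# warn=False silences the new non-convergence warning, e.g. in tight loops.
loss_quiet = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0,
                                            numIterMax=5, warn=False)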
@@ -2211,11 +2994,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
if isLazy:
if log:
- f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
- isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+ f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax,
+ stopThr=stopThr,
+ isLazy=isLazy,
+ batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
else:
- f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
- isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+ f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
+ numIterMax=numIterMax, stopThr=stopThr,
+ isLazy=isLazy, batchSize=batchSize,
+ verbose=verbose, log=log,
+ warn=warn)
bs = batchSize if isinstance(batchSize, int) else batchSize[0]
range_s = range(0, ns, bs)
@@ -2241,17 +3032,21 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
M = nx.from_numpy(M, type_as=a)
if log:
- sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
+ sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
return sinkhorn_loss, log
else:
- sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
- **kwargs)
+ sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ warn=warn, **kwargs)
return sinkhorn_loss
-def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
- verbose=False, log=False, **kwargs):
+def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
+ numIterMax=10000, stopThr=1e-9,
+ verbose=False, log=False, warn=True,
+ **kwargs):
r'''
Compute the sinkhorn divergence loss from empirical data
@@ -2288,8 +3083,11 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
\gamma_b &\geq 0
where :
- - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) is the (`n_samples_a`, `n_samples_b`) metric cost matrix (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
- - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
+ is the (`n_samples_a`, `n_samples_b`) metric cost matrix
+      (resp. (`n_samples_a`, `n_samples_a`) and (`n_samples_b`, `n_samples_b`))
+ - :math:`\Omega` is the entropic regularization term
+ :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
@@ -2313,6 +3111,8 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Print information along iterations
log : bool, optional
record log if True
+ warn : bool, optional
+        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
@@ -2334,17 +3134,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
References
----------
- .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
+ .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
+ Models with Sinkhorn Divergences, Proceedings of the Twenty-First
+        International Conference on Artificial Intelligence and Statistics,
+ (AISTATS) 21, 2018
'''
if log:
- sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
- sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax,
- stopThr=1e-9, verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax,
+ stopThr=1e-9, verbose=verbose,
+ log=log, warn=warn, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
@@ -2359,25 +3168,33 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
return max(0, sinkhorn_div), log
else:
- sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
- sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
- sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9,
- verbose=verbose, log=log, **kwargs)
+ sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
+ numIterMax=numIterMax, stopThr=1e-9,
+ verbose=verbose, log=log,
+ warn=warn, **kwargs)
sinkhorn_div = sinkhorn_loss_ab - 0.5 * (sinkhorn_loss_a + sinkhorn_loss_b)
return max(0, sinkhorn_div)
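As a sanity check on the definition above, the value returned by `empirical_sinkhorn_divergence` is the clipped combination of three Sinkhorn losses; a short illustrative sketch (arbitrary data, not from the patch):

# Illustrative sketch only, not part of the patch.
import numpy as np
import ot

rng = np.random.RandomState(42)
X_s = rng.randn(40, 2)
X_t = rng.randn(40, 2) + 1.0

div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)

# S(a, b) = OT_reg(a, b) - (OT_reg(a, a) + OT_reg(b, b)) / 2, clipped at 0
loss_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0)
loss_aa = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg=1.0)
loss_bb = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg=1.0)
by_hand = max(0, loss_ab - 0.5 * (loss_aa + loss_bb))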
-def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, restricted=True,
- maxiter=10000, maxfun=10000, pgtol=1e-09, verbose=False, log=False):
+def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
+ restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
+ verbose=False, log=False):
r"""
Screening Sinkhorn Algorithm for Regularized Optimal Transport
- The function solves an approximate dual of Sinkhorn divergence :ref:`[2] <references-screenkhorn>` which is written as the following optimization problem:
+    The function solves an approximate dual of Sinkhorn divergence :ref:`[2]
+    <references-screenkhorn>`, which is written as the following optimization problem:
.. math::
@@ -2395,56 +3212,49 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}
- The parameters `kappa` and `epsilon` are determined w.r.t the couple number budget of points (`ns_budget`, `nt_budget`), see Equation (5) in :ref:`[26] <references-screenkhorn>`
+    The parameters `kappa` and `epsilon` are determined w.r.t. the pair of
+    point budgets (`ns_budget`, `nt_budget`), see Equation (5)
+ in :ref:`[26] <references-screenkhorn>`
Parameters
----------
- a : array-like, shape=(ns,)
+ a: array-like, shape=(ns,)
samples weights in the source domain
-
- b : array-like, shape=(nt,)
+ b: array-like, shape=(nt,)
samples weights in the target domain
-
- M : array-like, shape=(ns, nt)
+ M: array-like, shape=(ns, nt)
Cost matrix
-
- reg : `float`
+ reg: `float`
        Level of the entropy regularization
-
- ns_budget : `int`, default=None
+ ns_budget: `int`, default=None
        Budget on the number of points to be kept in the source domain.
        If None, then 50% of the source sample points will be kept
-
- nt_budget : `int`, default=None
+ nt_budget: `int`, default=None
        Budget on the number of points to be kept in the target domain.
        If None, then 50% of the target sample points will be kept
-
- uniform : `bool`, default=False
- If `True`, the source and target distribution are supposed to be uniform, i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
-
+ uniform: `bool`, default=False
+        If `True`, the source and target distributions are assumed to be uniform,
+ i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
restricted : `bool`, default=True
        If `True`, use a warm-start initialization for the L-BFGS-B solver,
        obtained from a restricted Sinkhorn algorithm with at most 5 iterations
-
- maxiter : `int`, default=10000
+ maxiter: `int`, default=10000
Maximum number of iterations in LBFGS solver
-
- maxfun : `int`, default=10000
+ maxfun: `int`, default=10000
Maximum number of function evaluations in LBFGS solver
-
- pgtol : `float`, default=1e-09
+ pgtol: `float`, default=1e-09
Final objective function accuracy in LBFGS solver
-
- verbose : `bool`, default=False
- If `True`, display informations about the cardinals of the active sets and the parameters kappa
- and epsilon
-
+ verbose: `bool`, default=False
+        If `True`, display information about the cardinalities of the active sets
+ and the parameters kappa and epsilon
Dependency
----------
- To gain more efficiency, screenkhorn needs to call the "Bottleneck" package (https://pypi.org/project/Bottleneck/)
- in the screening pre-processing step. If Bottleneck isn't installed, the following error message appears:
+    For efficiency, screenkhorn calls the "Bottleneck"
+    package (https://pypi.org/project/Bottleneck/)
+    in the screening pre-processing step. If Bottleneck isn't installed,
+    the following warning message appears:
"Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/"
@@ -2461,9 +3271,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
References
-----------
- .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
+ .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
+ Advances in Neural Information Processing Systems (NIPS) 26, 2013
- .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
+ .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+ Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019
"""
# check if bottleneck module exists
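Before the implementation details, a hedged usage sketch of `screenkhorn` (not from the patch; the point budgets and `reg` below are arbitrary):

# Illustrative sketch only, not part of the patch.
import numpy as np
import ot

n = 100
rng = np.random.RandomState(0)
x = rng.randn(n, 2)
a = ot.unif(n)   # uniform source weights
b = ot.unif(n)   # uniform target weights
M = ot.dist(x, x)
M /= M.max()

# keep a budget of half the points on each side before solving
G_screen = ot.bregman.screenkhorn(a, b, M, reg=1e-2,
                                  ns_budget=n // 2, nt_budget=n // 2,
                                  uniform=True)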
@@ -2471,14 +3283,16 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
import bottleneck
except ImportError:
warnings.warn(
- "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.")
+ "Bottleneck module is not installed. Install it from"
+ " https://pypi.org/project/Bottleneck/ for better performance.")
bottleneck = np
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
if nx.__name__ == "jax":
- raise TypeError("JAX arrays have been received but screenkhorn is not compatible with JAX.")
+ raise TypeError("JAX arrays have been received but screenkhorn is not "
+ "compatible with JAX.")
ns, nt = M.shape
@@ -2582,7 +3396,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
if verbose:
print("epsilon = %s\n" % epsilon)
print("kappa = %s\n" % kappa)
- print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' % (sum(Isel), sum(Jsel)))
+ print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n'
+ % (sum(Isel), sum(Jsel)))
# Ic, Jc: complementary of the active sets I and J
Ic = ~Isel
@@ -2638,13 +3453,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa
- cpt = 1
- while cpt < 5: # 5 iterations
+ for _ in range(5): # 5 iterations
K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v
v0 = b_J / (kappa * K_IJ_v)
KIJ_u = nx.dot(K_IJ, v0) + cst_u
u0 = (kappa * a_I) / KIJ_u
- cpt += 1
u0 = projection(u0, epsilon / kappa)
v0 = projection(v0, epsilon * kappa)
@@ -2655,15 +3468,13 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res
def restricted_sinkhorn(usc, vsc, max_iter=5):
"""
- Restricted Sinkhorn Algorithm as a warm-start initialized point for L-BFGS-B (see Algorithm 1 in supplementary of [26])
+        Restricted Sinkhorn algorithm as a warm-start initialization for
+        L-BFGS-B (see Algorithm 1 in the supplementary of [26])
"""
- cpt = 1
- while cpt < max_iter:
+ for _ in range(max_iter):
K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
vsc = b_J / (kappa * K_IJ_v)
KIJ_u = nx.dot(K_IJ, vsc) + cst_u
usc = (kappa * a_I) / KIJ_u
- cpt += 1
usc = projection(usc, epsilon / kappa)
vsc = projection(vsc, epsilon * kappa)
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 6923d31..edfe9c3 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -6,6 +6,8 @@
#
# License: MIT License
+from itertools import product
+
import numpy as np
import pytest
@@ -13,7 +15,8 @@ import ot
from ot.backend import torch
-def test_sinkhorn():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn(verbose, warn):
# test sinkhorn
n = 100
rng = np.random.RandomState(0)
@@ -23,7 +26,7 @@ def test_sinkhorn():
M = ot.dist(x, x)
- G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
+ G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10, verbose=verbose, warn=warn)
# check constraints
np.testing.assert_allclose(
@@ -31,8 +34,92 @@ def test_sinkhorn():
np.testing.assert_allclose(
u, G.sum(0), atol=1e-05) # cf convergence sinkhorn
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(u, u, M, 1, stopThr=0, numItermax=1)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_epsilon_scaling",
+ "greenkhorn",
+ "sinkhorn_log"])
+def test_convergence_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ A = np.asarray([a1, a2]).T
+ M = ot.utils.dist0(n)
+
+ with pytest.warns(UserWarning):
+ ot.sinkhorn(a1, a2, M, 1., method=method, stopThr=0, numItermax=1)
+
+ if method in ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"]:
+ with pytest.warns(UserWarning):
+ ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1)
+
+
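The tests above exercise the new warning through pytest; outside a test suite, the same behaviour can be checked with the standard warnings module (a sketch under the same setup as the tests, parameters arbitrary):

# Illustrative sketch only, not part of the patch.
import warnings
import ot

n = 100
a = ot.datasets.make_1D_gauss(n, m=30, s=10)
b = ot.datasets.make_1D_gauss(n, m=40, s=10)
M = ot.utils.dist0(n)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ot.sinkhorn(a, b, M, 1., stopThr=0, numItermax=1)  # cannot converge
assert any(issubclass(w.category, UserWarning) for w in caught)

ot.sinkhorn(a, b, M, 1., stopThr=0, numItermax=1, warn=False)  # silent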
+def test_not_implemented_method():
+ # test sinkhorn
+ w = 10
+ n = w ** 2
+ rng = np.random.RandomState(42)
+ A_img = rng.rand(2, w, w)
+ A_flat = A_img.reshape(n, 2)
+ a1, a2 = A_flat.T
+ M_flat = ot.utils.dist0(n)
+ not_implemented = "new_method"
+ reg = 0.01
+ with pytest.raises(ValueError):
+ ot.sinkhorn(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.sinkhorn2(a1, a2, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.barycenter(A_flat, M_flat, reg, method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.barycenter_debiased(A_flat, M_flat, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d(A_img, reg,
+ method=not_implemented)
+ with pytest.raises(ValueError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg,
+ method=not_implemented)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
+def test_nan_warning(method):
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+
+ M = ot.utils.dist0(n)
+ reg = 0
+ with pytest.warns(UserWarning):
+ # warn set to False to avoid catching a convergence warning instead
+ ot.sinkhorn(a1, a2, M, reg, method=method, warn=False)
+
+
+def test_sinkhorn_stabilization():
+ # test sinkhorn
+ n = 100
+ a1 = ot.datasets.make_1D_gauss(n, m=30, s=10)
+ a2 = ot.datasets.make_1D_gauss(n, m=40, s=10)
+ M = ot.utils.dist0(n)
+ reg = 1e-5
+ loss1 = ot.sinkhorn2(a1, a2, M, reg, method="sinkhorn_log")
+ loss2 = ot.sinkhorn2(a1, a2, M, reg, tau=1, method="sinkhorn_stabilized")
+ np.testing.assert_allclose(
+ loss1, loss2, atol=1e-06) # cf convergence sinkhorn
+
-def test_sinkhorn_multi_b():
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized",
+ "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_sinkhorn_multi_b(method, verbose, warn):
# test sinkhorn
n = 10
rng = np.random.RandomState(0)
@@ -45,12 +132,14 @@ def test_sinkhorn_multi_b():
M = ot.dist(x, x)
- loss0, log = ot.sinkhorn(u, b, M, .1, stopThr=1e-10, log=True)
+ loss0, log = ot.sinkhorn(u, b, M, .1, method=method, stopThr=1e-10,
+ log=True)
- loss = [ot.sinkhorn2(u, b[:, k], M, .1, stopThr=1e-10) for k in range(3)]
+ loss = [ot.sinkhorn2(u, b[:, k], M, .1, method=method, stopThr=1e-10,
+ verbose=verbose, warn=warn) for k in range(3)]
# check constraints
np.testing.assert_allclose(
- loss0, loss, atol=1e-06) # cf convergence sinkhorn
+ loss0, loss, atol=1e-4) # cf convergence sinkhorn
def test_sinkhorn_backends(nx):
@@ -67,9 +156,9 @@ def test_sinkhorn_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
- Gb = ot.sinkhorn(ab, ab, Mb, 1)
+ Gb = ot.sinkhorn(ab, ab, M_nx, 1)
np.allclose(G, nx.to_numpy(Gb))
@@ -88,9 +177,9 @@ def test_sinkhorn2_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
- Gb = ot.sinkhorn2(ab, ab, Mb, 1)
+ Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
np.allclose(G, nx.to_numpy(Gb))
@@ -131,6 +220,12 @@ def test_sinkhorn_empty():
M = ot.dist(x, x)
+ G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, method="sinkhorn_log",
+ verbose=True, log=True)
+ # check constraints
+ np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
+ np.testing.assert_allclose(u, G.sum(0), atol=1e-05)
+
G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
# check constraints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
@@ -165,15 +260,15 @@ def test_sinkhorn_variants(nx):
M = ot.dist(x, x)
ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
Ges = nx.to_numpy(ot.sinkhorn(
- ub, ub, Mb, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
- G_green = nx.to_numpy(ot.sinkhorn(ub, ub, Mb, 1, method='greenkhorn', stopThr=1e-10))
+ ub, ub, M_nx, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10))
+ G_green = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='greenkhorn', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -199,12 +294,12 @@ def test_sinkhorn_variants_multi_b(nx):
ub = nx.from_numpy(u)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -228,12 +323,12 @@ def test_sinkhorn2_variants_multi_b(nx):
ub = nx.from_numpy(u)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
- Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_log', stopThr=1e-10))
- G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn', stopThr=1e-10))
- Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, Mb, 1, method='sinkhorn_stabilized', stopThr=1e-10))
+ Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
+ G0 = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn', stopThr=1e-10))
+ Gs = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_stabilized', stopThr=1e-10))
# check values
np.testing.assert_allclose(G, G0, atol=1e-05)
@@ -255,7 +350,7 @@ def test_sinkhorn_variants_log():
Gl, logl = ot.sinkhorn(u, u, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
Gs, logs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
Ges, loges = ot.sinkhorn(
- u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
+        u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10, log=True)
G_green, loggreen = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10, log=True)
# check values
@@ -265,7 +360,8 @@ def test_sinkhorn_variants_log():
np.testing.assert_allclose(G0, G_green, atol=1e-5)
-def test_sinkhorn_variants_log_multib():
+@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False]))
+def test_sinkhorn_variants_log_multib(verbose, warn):
# test sinkhorn
n = 50
rng = np.random.RandomState(0)
@@ -278,16 +374,20 @@ def test_sinkhorn_variants_log_multib():
M = ot.dist(x, x)
G0, log0 = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10, log=True)
- Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True)
- Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True)
+ Gl, logl = ot.sinkhorn(u, b, M, 1, method='sinkhorn_log', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
+ Gs, logs = ot.sinkhorn(u, b, M, 1, method='sinkhorn_stabilized', stopThr=1e-10, log=True,
+ verbose=verbose, warn=warn)
# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Gl, atol=1e-05)
-@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_barycenter(nx, method):
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter(nx, method, verbose, warn):
n_bins = 100 # nb bins
# Gaussian distributions
@@ -304,20 +404,98 @@ def test_barycenter(nx, method):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- Ab = nx.from_numpy(A)
- Mb = nx.from_numpy(M)
- weightsb = nx.from_numpy(weights)
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
+ reg = 1e-2
+
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter(A_nx, M_nx, reg, weights, method=method)
+ else:
+ # wasserstein
+ bary_wass_np = ot.bregman.barycenter(A, M, reg, weights, method=method, verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass))
+ np.testing.assert_allclose(bary_wass, bary_wass_np)
+
+ ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
+
+
+@pytest.mark.parametrize("method, verbose, warn",
+ product(["sinkhorn", "sinkhorn_log"],
+ [True, False], [True, False]))
+def test_barycenter_debiased(nx, method, verbose, warn):
+ n_bins = 100 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
+ weights_nx = nx.from_numpy(weights)
# wasserstein
reg = 1e-2
- bary_wass_np, log = ot.bregman.barycenter(A, M, reg, weights, method=method, log=True)
- bary_wass, _ = ot.bregman.barycenter(Ab, Mb, reg, weightsb, method=method, log=True)
- bary_wass = nx.to_numpy(bary_wass)
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights, method=method)
+ else:
+ bary_wass_np = ot.bregman.barycenter_debiased(A, M, reg, weights, method=method,
+ verbose=verbose, warn=warn)
+ bary_wass, _ = ot.bregman.barycenter_debiased(A_nx, M_nx, reg, weights_nx, method=method, log=True)
+ bary_wass = nx.to_numpy(bary_wass)
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), atol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-5)
+
+ ot.bregman.barycenter_debiased(A_nx, M_nx, reg, log=True, verbose=False)
- np.testing.assert_allclose(1, np.sum(bary_wass))
- np.testing.assert_allclose(bary_wass, bary_wass_np)
- ot.bregman.barycenter(Ab, Mb, reg, log=True, verbose=True)
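For readers skimming the tests, a minimal sketch of the new debiased barycenter API they cover (not part of the patch; `reg` is arbitrary and weights default to uniform):

# Illustrative sketch only, not part of the patch.
import numpy as np
import ot

n_bins = 100
a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10)
a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
A = np.vstack((a1, a2)).T   # one histogram per column
M = ot.utils.dist0(n_bins)
M /= M.max()

bary_deb = ot.bregman.barycenter_debiased(A, M, reg=1e-2)   # debiased
bary_ent = ot.bregman.barycenter(A, M, reg=1e-2)            # entropic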
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_convergence_warning_barycenters(method):
+ w = 10
+ n_bins = w ** 2 # nb bins
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n_bins, m=40, s=10)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ A_img = A.reshape(2, w, w)
+ A_img /= A_img.sum((1, 2))[:, None, None]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n_bins)
+ M /= M.max()
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+ reg = 0.1
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter_debiased(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.barycenter(A, M, reg, weights, method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d(A_img, reg, weights,
+ method=method, numItermax=1)
+ with pytest.warns(UserWarning):
+ ot.bregman.convolutional_barycenter2d_debiased(A_img, reg, weights,
+ method=method, numItermax=1)
def test_barycenter_stabilization(nx):
@@ -337,31 +515,64 @@ def test_barycenter_stabilization(nx):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- Ab = nx.from_numpy(A)
- Mb = nx.from_numpy(M)
+ A_nx = nx.from_numpy(A)
+ M_nx = nx.from_numpy(M)
weights_b = nx.from_numpy(weights)
# wasserstein
reg = 1e-2
bar_np = ot.bregman.barycenter(A, M, reg, weights, method="sinkhorn", stopThr=1e-8, verbose=True)
bar_stable = nx.to_numpy(ot.bregman.barycenter(
- Ab, Mb, reg, weights_b, method="sinkhorn_stabilized",
+ A_nx, M_nx, reg, weights_b, method="sinkhorn_stabilized",
stopThr=1e-8, verbose=True
))
bar = nx.to_numpy(ot.bregman.barycenter(
- Ab, Mb, reg, weights_b, method="sinkhorn",
+ A_nx, M_nx, reg, weights_b, method="sinkhorn",
stopThr=1e-8, verbose=True
))
np.testing.assert_allclose(bar, bar_stable)
np.testing.assert_allclose(bar, bar_np)
-def test_wasserstein_bary_2d(nx):
- size = 100 # size of a square image
- a1 = np.random.randn(size, size)
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
+ a1 += a1.min()
+ a1 = a1 / np.sum(a1)
+ a2 = np.random.rand(size, size)
+ a2 += a2.min()
+ a2 = a2 / np.sum(a2)
+ # creating matrix A containing all distributions
+ A = np.zeros((2, size, size))
+ A[0, :, :] = a1
+ A[1, :, :] = a2
+
+ A_nx = nx.from_numpy(A)
+
+ # wasserstein
+ reg = 1e-2
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, reg, method=method))
+
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
+
+    # check that log and verbose do not break the function
+ ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+
+
+@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
+def test_wasserstein_bary_2d_debiased(nx, method):
+ size = 20 # size of a square image
+ a1 = np.random.rand(size, size)
a1 += a1.min()
a1 = a1 / np.sum(a1)
- a2 = np.random.randn(size, size)
+ a2 = np.random.rand(size, size)
a2 += a2.min()
a2 = a2 / np.sum(a2)
# creating matrix A containing all distributions
@@ -369,18 +580,22 @@ def test_wasserstein_bary_2d(nx):
A[0, :, :] = a1
A[1, :, :] = a2
- Ab = nx.from_numpy(A)
+ A_nx = nx.from_numpy(A)
# wasserstein
reg = 1e-2
- bary_wass_np = ot.bregman.convolutional_barycenter2d(A, reg)
- bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, reg))
+ if nx.__name__ == "jax" and method == "sinkhorn_log":
+ with pytest.raises(NotImplementedError):
+ ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method)
+ else:
+ bary_wass_np = ot.bregman.convolutional_barycenter2d_debiased(A, reg, method=method)
+ bary_wass = nx.to_numpy(ot.bregman.convolutional_barycenter2d_debiased(A_nx, reg, method=method))
- np.testing.assert_allclose(1, np.sum(bary_wass))
- np.testing.assert_allclose(bary_wass, bary_wass_np)
+ np.testing.assert_allclose(1, np.sum(bary_wass), rtol=1e-3)
+ np.testing.assert_allclose(bary_wass, bary_wass_np, atol=1e-3)
- # help in checking if log and verbose do not bug the function
- ot.bregman.convolutional_barycenter2d(A, reg, log=True, verbose=True)
+    # check that log and verbose do not break the function
+    ot.bregman.convolutional_barycenter2d_debiased(A, reg, log=True, verbose=True)
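And the 2D counterpart exercised by the test above, as a hedged sketch (random images and an arbitrary `reg`, not from the patch):

# Illustrative sketch only, not part of the patch.
import numpy as np
import ot

rng = np.random.RandomState(0)
size = 20
A = rng.rand(2, size, size)               # two random images
A /= A.sum(axis=(1, 2), keepdims=True)    # normalize each to sum to 1

bary = ot.bregman.convolutional_barycenter2d(A, reg=1e-2)
bary_deb = ot.bregman.convolutional_barycenter2d_debiased(A, reg=1e-2)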
def test_unmix(nx):
@@ -405,20 +620,20 @@ def test_unmix(nx):
ab = nx.from_numpy(a)
Db = nx.from_numpy(D)
- Mb = nx.from_numpy(M)
+ M_nx = nx.from_numpy(M)
M0b = nx.from_numpy(M0)
h0b = nx.from_numpy(h0)
# wasserstein
reg = 1e-3
um_np = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01)
- um = nx.to_numpy(ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg, 1, alpha=0.01))
+ um = nx.to_numpy(ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg, 1, alpha=0.01))
np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)
np.testing.assert_allclose(um, um_np)
- ot.bregman.unmix(ab, Db, Mb, M0b, h0b, reg,
+ ot.bregman.unmix(ab, Db, M_nx, M0b, h0b, reg,
1, alpha=0.01, log=True, verbose=True)
@@ -437,22 +652,22 @@ def test_empirical_sinkhorn(nx):
bb = nx.from_numpy(b)
X_sb = nx.from_numpy(X_s)
X_tb = nx.from_numpy(X_t)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
M_mb = nx.from_numpy(M_m, type_as=ab)
G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
- sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
G_log, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, log=True)
G_log = nx.to_numpy(G_log)
- sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
G_m = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean'))
sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1))
loss_emp_sinkhorn = nx.to_numpy(ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1))
- loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
np.testing.assert_allclose(
@@ -486,18 +701,18 @@ def test_lazy_empirical_sinkhorn(nx):
bb = nx.from_numpy(b)
X_sb = nx.from_numpy(X_s)
X_tb = nx.from_numpy(X_t)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
M_mb = nx.from_numpy(M_m, type_as=ab)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
- sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1))
+ sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
f, g, log_es = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
- sinkhorn_log, log_s = ot.sinkhorn(ab, bb, Mb, 0.1, log=True)
+ sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True)
sinkhorn_log = nx.to_numpy(sinkhorn_log)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric='euclidean', numIterMax=numIterMax, isLazy=True, batchSize=1)
@@ -507,7 +722,7 @@ def test_lazy_empirical_sinkhorn(nx):
loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn)
- loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, Mb, 1))
+ loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1))
# check constraints
np.testing.assert_allclose(
@@ -541,13 +756,13 @@ def test_empirical_sinkhorn_divergence(nx):
bb = nx.from_numpy(b)
X_sb = nx.from_numpy(X_s)
X_tb = nx.from_numpy(X_t)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
M_sb = nx.from_numpy(M_s, type_as=ab)
M_tb = nx.from_numpy(M_t, type_as=ab)
emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
- ot.sinkhorn2(ab, bb, Mb, 1)
+ ot.sinkhorn2(ab, bb, M_nx, 1)
- 1 / 2 * ot.sinkhorn2(ab, ab, M_sb, 1)
- 1 / 2 * ot.sinkhorn2(bb, bb, M_tb, 1)
)
@@ -580,14 +795,14 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
- G, log = ot.bregman.sinkhorn(ab, bb, Mb, reg=epsilon,
+ G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
method="sinkhorn_stabilized",
log=True)
G = nx.to_numpy(G)
- G2, log2 = ot.bregman.sinkhorn(ab, bb, Mb, epsilon,
+ G2, log2 = ot.bregman.sinkhorn(ab, bb, M_nx, epsilon,
method="sinkhorn", log=True)
G2 = nx.to_numpy(G2)
@@ -642,14 +857,14 @@ def test_screenkhorn(nx):
ab = nx.from_numpy(a)
bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ M_nx = nx.from_numpy(M, type_as=ab)
# np sinkhorn
G_sink_np = ot.sinkhorn(a, b, M, 1e-03)
# sinkhorn
- G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, Mb, 1e-03))
+ G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-03))
# screenkhorn
- G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, Mb, 1e-03, uniform=True, verbose=True))
+ G_screen = nx.to_numpy(ot.bregman.screenkhorn(ab, bb, M_nx, 1e-03, uniform=True, verbose=True))
# check marginals
np.testing.assert_allclose(G_sink_np, G_sink)
np.testing.assert_allclose(G_sink.sum(0), G_screen.sum(0), atol=1e-02)
@@ -659,10 +874,10 @@ def test_screenkhorn(nx):
def test_convolutional_barycenter_non_square(nx):
# test for image with height not equal width
A = np.ones((2, 2, 3)) / (2 * 3)
- Ab = nx.from_numpy(A)
+ A_nx = nx.from_numpy(A)
b_np = ot.bregman.convolutional_barycenter2d(A, 1e-03)
- b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(Ab, 1e-03))
+ b = nx.to_numpy(ot.bregman.convolutional_barycenter2d(A_nx, 1e-03))
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)
np.testing.assert_allclose(np.ones((2, 3)) / (2 * 3), b, atol=1e-02)