summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authortvayer <titouan.vayer@gmail.com>2019-05-29 14:24:05 +0200
committertvayer <titouan.vayer@gmail.com>2019-05-29 14:24:05 +0200
commit63bbeb34e48f02c97a762dab5232158d90a5cffc (patch)
tree853026b5854b6e4b01fdf750db139985b3dd596f /examples
parentf70aabfcc11f92181e0dc987b341bad8ec030d75 (diff)
parentf66ab58c7c895011fd37bafd3e848828399c56c4 (diff)
Merge remote-tracking branch 'rflamary/master'
merge pot
Diffstat (limited to 'examples')
-rw-r--r--examples/plot_OT_2D_samples.py26
-rw-r--r--examples/plot_barycenter_1D.py8
-rw-r--r--examples/plot_convolutional_barycenter.py92
-rw-r--r--examples/plot_free_support_barycenter.py69
-rw-r--r--examples/plot_otda_color_images.py8
-rw-r--r--examples/plot_otda_mapping_colors_images.py2
-rw-r--r--examples/plot_stochastic.py208
7 files changed, 404 insertions, 9 deletions
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index bb952a0..63126ba 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -10,6 +10,7 @@ sum of diracs. The OT matrix is plotted with the samples.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
#
# License: MIT License
@@ -100,3 +101,28 @@ pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
pl.show()
+
+
+##############################################################################
+# Emprirical Sinkhorn
+# ----------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn from samples')
+
+pl.show()
diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py
index 5ed9f3f..6864301 100644
--- a/examples/plot_barycenter_1D.py
+++ b/examples/plot_barycenter_1D.py
@@ -25,7 +25,7 @@ import ot
from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection
-#
+##############################################################################
# Generate data
# -------------
@@ -48,7 +48,7 @@ n_distributions = A.shape[1]
M = ot.utils.dist0(n)
M /= M.max()
-#
+##############################################################################
# Plot data
# ---------
@@ -60,7 +60,7 @@ for i in range(n_distributions):
pl.title('Distributions')
pl.tight_layout()
-#
+##############################################################################
# Barycenter computation
# ----------------------
@@ -90,7 +90,7 @@ pl.legend()
pl.title('Barycenters')
pl.tight_layout()
-#
+##############################################################################
# Barycentric interpolation
# -------------------------
diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py
new file mode 100644
index 0000000..e74db04
--- /dev/null
+++ b/examples/plot_convolutional_barycenter.py
@@ -0,0 +1,92 @@
+
+#%%
+# -*- coding: utf-8 -*-
+"""
+============================================
+Convolutional Wasserstein Barycenter example
+============================================
+
+This example is designed to illustrate how the Convolutional Wasserstein Barycenter
+function of POT works.
+"""
+
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
+f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
+f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
+f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+
+A = []
+f1 = f1 / np.sum(f1)
+f2 = f2 / np.sum(f2)
+f3 = f3 / np.sum(f3)
+f4 = f4 / np.sum(f4)
+A.append(f1)
+A.append(f2)
+A.append(f3)
+A.append(f4)
+A = np.array(A)
+
+nb_images = 5
+
+# those are the four corners coordinates that will be interpolated by bilinear
+# interpolation
+v1 = np.array((1, 0, 0, 0))
+v2 = np.array((0, 1, 0, 0))
+v3 = np.array((0, 0, 1, 0))
+v4 = np.array((0, 0, 0, 1))
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+pl.figure(figsize=(10, 10))
+pl.title('Convolutional Wasserstein Barycenters in POT')
+cm = 'Blues'
+# regularization parameter
+reg = 0.004
+for i in range(nb_images):
+ for j in range(nb_images):
+ pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
+ tx = float(i) / (nb_images - 1)
+ ty = float(j) / (nb_images - 1)
+
+ # weights are constructed by bilinear interpolation
+ tmp1 = (1 - tx) * v1 + tx * v2
+ tmp2 = (1 - tx) * v3 + tx * v4
+ weights = (1 - ty) * tmp1 + ty * tmp2
+
+ if i == 0 and j == 0:
+ pl.imshow(f1, cmap=cm)
+ pl.axis('off')
+ elif i == 0 and j == (nb_images - 1):
+ pl.imshow(f3, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == 0:
+ pl.imshow(f2, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == (nb_images - 1):
+ pl.imshow(f4, cmap=cm)
+ pl.axis('off')
+ else:
+ # call to barycenter computation
+ pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
+ pl.axis('off')
+pl.show()
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
new file mode 100644
index 0000000..b6efc59
--- /dev/null
+++ b/examples/plot_free_support_barycenter.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D free support Wasserstein barycenters of distributions
+====================================================
+
+Illustration of 2D Wasserstein barycenters if discributions that are weighted
+sum of diracs.
+
+"""
+
+# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+#%% parameters and data generation
+N = 3
+d = 2
+measures_locations = []
+measures_weights = []
+
+for i in range(N):
+
+ n_i = np.random.randint(low=1, high=20) # nb samples
+
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
+
+ measures_locations.append(x_i)
+ measures_weights.append(b_i)
+
+
+##############################################################################
+# Compute free support barycenter
+# -------------
+
+k = 10 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1)
+for (x_i, b_i) in zip(measures_locations, measures_weights):
+ color = np.random.randint(low=1, high=10 * N)
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc=0)
+pl.show()
diff --git a/examples/plot_otda_color_images.py b/examples/plot_otda_color_images.py
index e77aec0..62383a2 100644
--- a/examples/plot_otda_color_images.py
+++ b/examples/plot_otda_color_images.py
@@ -4,7 +4,7 @@
OT for image color adaptation
=============================
-This example presents a way of transferring colors between two image
+This example presents a way of transferring colors between two images
with Optimal Transport as introduced in [6]
[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
@@ -27,7 +27,7 @@ r = np.random.RandomState(42)
def im2mat(I):
- """Converts and image to matrix (one pixel per line)"""
+ """Converts an image to matrix (one pixel per line)"""
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
@@ -115,8 +115,8 @@ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
transp_Xs_emd = ot_emd.transform(Xs=X1)
transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
-transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
-transp_Xt_sinkhorn = ot_emd.inverse_transform(Xt=X2)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+transp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)
I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
diff --git a/examples/plot_otda_mapping_colors_images.py b/examples/plot_otda_mapping_colors_images.py
index 5f1e844..a20eca8 100644
--- a/examples/plot_otda_mapping_colors_images.py
+++ b/examples/plot_otda_mapping_colors_images.py
@@ -77,7 +77,7 @@ Image_emd = minmax(mat2im(transp_Xs_emd, I1.shape))
# SinkhornTransport
ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
-transp_Xs_sinkhorn = ot_emd.transform(Xs=X1)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
ot_mapping_linear = ot.da.MappingTransport(
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
new file mode 100644
index 0000000..742f8d9
--- /dev/null
+++ b/examples/plot_stochastic.py
@@ -0,0 +1,208 @@
+"""
+==========================
+Stochastic examples
+==========================
+
+This example is designed to show how to use the stochatic optimization
+algorithms for descrete and semicontinous measures from the POT library.
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+import ot
+import ot.plot
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
+#############################################################################
+#############################################################################
+# DISCRETE CASE:
+#
+# Sample two discrete measures for the discrete case
+# ---------------------------------------------
+#
+# Define 2 discrete measures a and b, the points where are defined the source
+# and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SAG" method to find the transportation matrix in the discrete case
+# ---------------------------------------------
+#
+# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
+# results.
+
+method = "SAG"
+sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax)
+print(sag_pi)
+
+#############################################################################
+# SEMICONTINOUS CASE:
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "ASGD" method to find the transportation matrix in the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
+# results.
+
+method = "ASGD"
+asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, log=log)
+print(log_asgd['alpha'], log_asgd['beta'])
+print(asgd_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+
+##############################################################################
+# PLOT TRANSPORTATION MATRIX
+##############################################################################
+
+##############################################################################
+# Plot SAG results
+# ----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
+pl.show()
+
+
+##############################################################################
+# Plot ASGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
+#############################################################################
+#############################################################################
+# SEMICONTINOUS CASE:
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 100000
+lr = 0.1
+batch_size = 3
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SGD" dual method to find the transportation matrix in the
+# semicontinous case
+# ---------------------------------------------
+#
+# Call ot.solve_dual_entropic and plot the results.
+
+sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
+ batch_size, numItermax,
+ lr, log=log)
+print(log_sgd['alpha'], log_sgd['beta'])
+print(sgd_dual_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+##############################################################################
+# Plot SGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()