summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2020-01-20 14:07:53 +0100
committerGard Spreemann <gspr@nonempty.org>2020-01-20 14:07:53 +0100
commitbdfb24ff37ea777d6e266b145047cd4e281ebac3 (patch)
tree00cbac5f3dc25a4ee76164828abd72c1cbab37cc /examples
parentabc441b00f0fe2fa4ef0efc4e1aa67b27cca9a13 (diff)
parent5e70a77fbb2feec513f21c9ef65dcc535329ace6 (diff)
Merge tag '0.6.0' into debian/sid
Diffstat (limited to 'examples')
-rw-r--r--examples/README.txt4
-rw-r--r--examples/plot_OT_1D.py84
-rw-r--r--examples/plot_OT_1D_smooth.py110
-rw-r--r--examples/plot_OT_2D_samples.py128
-rw-r--r--examples/plot_OT_L1_vs_L2.py208
-rw-r--r--examples/plot_UOT_1D.py76
-rw-r--r--examples/plot_UOT_barycenter_1D.py164
-rw-r--r--examples/plot_WDA.py127
-rw-r--r--examples/plot_barycenter_1D.py160
-rw-r--r--examples/plot_barycenter_fgw.py184
-rw-r--r--examples/plot_barycenter_lp_vs_entropic.py286
-rw-r--r--examples/plot_compute_emd.py102
-rw-r--r--examples/plot_convolutional_barycenter.py92
-rw-r--r--examples/plot_fgw.py173
-rw-r--r--examples/plot_free_support_barycenter.py69
-rw-r--r--examples/plot_gromov.py106
-rwxr-xr-xexamples/plot_gromov_barycenter.py248
-rw-r--r--examples/plot_optim_OTreg.py129
-rw-r--r--examples/plot_otda_classes.py150
-rw-r--r--examples/plot_otda_color_images.py165
-rw-r--r--examples/plot_otda_d2.py172
-rw-r--r--examples/plot_otda_linear_mapping.py144
-rw-r--r--examples/plot_otda_mapping.py125
-rw-r--r--examples/plot_otda_mapping_colors_images.py174
-rw-r--r--examples/plot_otda_semi_supervised.py148
-rw-r--r--examples/plot_stochastic.py208
26 files changed, 3736 insertions, 0 deletions
diff --git a/examples/README.txt b/examples/README.txt
new file mode 100644
index 0000000..b08d3f1
--- /dev/null
+++ b/examples/README.txt
@@ -0,0 +1,4 @@
+POT Examples
+============
+
+This is a gallery of all the POT example files.
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
new file mode 100644
index 0000000..f33e2a4
--- /dev/null
+++ b/examples/plot_OT_1D.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+"""
+====================
+1D optimal transport
+====================
+
+This example illustrates the computation of EMD and Sinkhorn transport plans
+and their visualization.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve EMD
+# ---------
+
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve Sinkhorn
+# --------------
+
+
+#%% Sinkhorn
+
+lambd = 1e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+pl.show()
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
new file mode 100644
index 0000000..b690751
--- /dev/null
+++ b/examples/plot_OT_1D_smooth.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+"""
+===========================
+1D smooth optimal transport
+===========================
+
+This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
+and their visualization.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve EMD
+# ---------
+
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve Sinkhorn
+# --------------
+
+
+#%% Sinkhorn
+
+lambd = 2e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+pl.show()
+
+##############################################################################
+# Solve Smooth OT
+# --------------
+
+
+#%% Smooth OT with KL regularization
+
+lambd = 2e-3
+Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
+
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
+
+pl.show()
+
+
+#%% Smooth OT with KL regularization
+
+lambd = 1e-1
+Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
+
+pl.figure(6, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
+
+pl.show()
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
new file mode 100644
index 0000000..63126ba
--- /dev/null
+++ b/examples/plot_OT_2D_samples.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D Optimal transport between empirical distributions
+====================================================
+
+Illustration of 2D optimal transport between discributions that are weighted
+sum of diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+##############################################################################
+# Compute EMD
+# -----------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix G0')
+
+pl.figure(4)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix with samples')
+
+
+##############################################################################
+# Compute Sinkhorn
+# ----------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Gs = ot.sinkhorn(a, b, M, lambd)
+
+pl.figure(5)
+pl.imshow(Gs, interpolation='nearest')
+pl.title('OT matrix sinkhorn')
+
+pl.figure(6)
+ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn with samples')
+
+pl.show()
+
+
+##############################################################################
+# Emprirical Sinkhorn
+# ----------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn from samples')
+
+pl.show()
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
new file mode 100644
index 0000000..37b429f
--- /dev/null
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+2D Optimal transport for different metrics
+==========================================
+
+2D OT on empirical distributio with different gound metric.
+
+Stole the figure idea from Fig. 1 and 2 in
+https://arxiv.org/pdf/1706.07650.pdf
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Dataset 1 : uniform sampling
+# ----------------------------
+
+n = 20 # nb samples
+xs = np.zeros((n, 2))
+xs[:, 0] = np.arange(n) + 1
+xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+xt = np.zeros((n, 2))
+xt[:, 1] = np.arange(n) + 1
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M1 = ot.dist(xs, xt, metric='euclidean')
+M1 /= M1.max()
+
+# loss matrix
+M2 = ot.dist(xs, xt, metric='sqeuclidean')
+M2 /= M2.max()
+
+# loss matrix
+Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp /= Mp.max()
+
+# Data
+pl.figure(1, figsize=(7, 3))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+pl.title('Source and target distributions')
+
+
+# Cost matrices
+pl.figure(2, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+pl.imshow(M1, interpolation='nearest')
+pl.title('Euclidean cost')
+
+pl.subplot(1, 3, 2)
+pl.imshow(M2, interpolation='nearest')
+pl.title('Squared Euclidean cost')
+
+pl.subplot(1, 3, 3)
+pl.imshow(Mp, interpolation='nearest')
+pl.title('Sqrt Euclidean cost')
+pl.tight_layout()
+
+##############################################################################
+# Dataset 1 : Plot OT Matrices
+# ----------------------------
+
+
+#%% EMD
+G1 = ot.emd(a, b, M1)
+G2 = ot.emd(a, b, M2)
+Gp = ot.emd(a, b, Mp)
+
+# OT matrices
+pl.figure(3, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT Euclidean')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT squared Euclidean')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT sqrt Euclidean')
+pl.tight_layout()
+
+pl.show()
+
+
+##############################################################################
+# Dataset 2 : Partial circle
+# --------------------------
+
+n = 50 # nb samples
+xtot = np.zeros((n + 1, 2))
+xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+xs = xtot[:n, :]
+xt = xtot[1:, :]
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M1 = ot.dist(xs, xt, metric='euclidean')
+M1 /= M1.max()
+
+# loss matrix
+M2 = ot.dist(xs, xt, metric='sqeuclidean')
+M2 /= M2.max()
+
+# loss matrix
+Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp /= Mp.max()
+
+
+# Data
+pl.figure(4, figsize=(7, 3))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+pl.title('Source and traget distributions')
+
+
+# Cost matrices
+pl.figure(5, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+pl.imshow(M1, interpolation='nearest')
+pl.title('Euclidean cost')
+
+pl.subplot(1, 3, 2)
+pl.imshow(M2, interpolation='nearest')
+pl.title('Squared Euclidean cost')
+
+pl.subplot(1, 3, 3)
+pl.imshow(Mp, interpolation='nearest')
+pl.title('Sqrt Euclidean cost')
+pl.tight_layout()
+
+##############################################################################
+# Dataset 2 : Plot OT Matrices
+# -----------------------------
+
+
+#%% EMD
+G1 = ot.emd(a, b, M1)
+G2 = ot.emd(a, b, M2)
+Gp = ot.emd(a, b, Mp)
+
+# OT matrices
+pl.figure(6, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT Euclidean')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT squared Euclidean')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT sqrt Euclidean')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_UOT_1D.py b/examples/plot_UOT_1D.py
new file mode 100644
index 0000000..2ea8b05
--- /dev/null
+++ b/examples/plot_UOT_1D.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Unbalanced optimal transport
+===============================
+
+This example illustrates the computation of Unbalanced Optimal transport
+using a Kullback-Leibler relaxation.
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# make distributions unbalanced
+b *= 5.
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+##############################################################################
+# Solve Unbalanced Sinkhorn
+# --------------
+
+
+# Sinkhorn
+
+epsilon = 0.1 # entropy parameter
+alpha = 1. # Unbalanced KL relaxation parameter
+Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
+
+pl.show()
diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/plot_UOT_barycenter_1D.py
new file mode 100644
index 0000000..c8d9d3b
--- /dev/null
+++ b/examples/plot_UOT_barycenter_1D.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection
+
+##############################################################################
+# Generate data
+# -------------
+
+# parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# make unbalanced dists
+a2 *= 3.
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+# plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+# non weighted barycenter computation
+
+weight = 0.5 # 0<=weight<=1
+weights = np.array([1 - weight, weight])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+alpha = 1.
+
+bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+##############################################################################
+# Barycentric interpolation
+# -------------------------
+
+# barycenter interpolation
+
+n_weight = 11
+weight_list = np.linspace(0, 1, n_weight)
+
+
+B_l2 = np.zeros((n, n_weight))
+
+B_wass = np.copy(B_l2)
+
+for i in range(0, n_weight):
+ weight = weight_list[i]
+ weights = np.array([1 - weight, weight])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+
+# plot interpolation
+
+pl.figure(3)
+
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
+
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_WDA.py b/examples/plot_WDA.py
new file mode 100644
index 0000000..93cc237
--- /dev/null
+++ b/examples/plot_WDA.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Wasserstein Discriminant Analysis
+=================================
+
+This example illustrate the use of WDA as proposed in [11].
+
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+
+from ot.dr import wda, fda
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 1000 # nb samples in source and target datasets
+nz = 0.2
+
+# generate circle dataset
+t = np.random.rand(n) * 2 * np.pi
+ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+t = np.random.rand(n) * 2 * np.pi
+yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+nbnoise = 8
+
+xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot samples
+pl.figure(1, figsize=(6.4, 3.5))
+
+pl.subplot(1, 2, 1)
+pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
+pl.legend(loc=0)
+pl.title('Discriminant dimensions')
+
+pl.subplot(1, 2, 2)
+pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+pl.legend(loc=0)
+pl.title('Other dimensions')
+pl.tight_layout()
+
+##############################################################################
+# Compute Fisher Discriminant Analysis
+# ------------------------------------
+
+#%% Compute FDA
+p = 2
+
+Pfda, projfda = fda(xs, ys, p)
+
+##############################################################################
+# Compute Wasserstein Discriminant Analysis
+# -----------------------------------------
+
+#%% Compute WDA
+p = 2
+reg = 1e0
+k = 10
+maxiter = 100
+
+Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+
+##############################################################################
+# Plot 2D projections
+# -------------------
+
+#%% plot samples
+
+xsp = projfda(xs)
+xtp = projfda(xt)
+
+xspw = projwda(xs)
+xtpw = projwda(xt)
+
+pl.figure(2)
+
+pl.subplot(2, 2, 1)
+pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples FDA')
+
+pl.subplot(2, 2, 2)
+pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples FDA')
+
+pl.subplot(2, 2, 3)
+pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples WDA')
+
+pl.subplot(2, 2, 4)
+pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples WDA')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_barycenter_1D.py b/examples/plot_barycenter_1D.py
new file mode 100644
index 0000000..6864301
--- /dev/null
+++ b/examples/plot_barycenter_1D.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+1D Wasserstein barycenter demo
+==============================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [3].
+
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+##############################################################################
+# Barycentric interpolation
+# -------------------------
+
+#%% barycenter interpolation
+
+n_alpha = 11
+alpha_list = np.linspace(0, 1, n_alpha)
+
+
+B_l2 = np.zeros((n, n_alpha))
+
+B_wass = np.copy(B_l2)
+
+for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
+
+#%% plot interpolation
+
+pl.figure(3)
+
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel('$\\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
+
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel('$\\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_barycenter_fgw.py b/examples/plot_barycenter_fgw.py
new file mode 100644
index 0000000..77b0370
--- /dev/null
+++ b/examples/plot_barycenter_fgw.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+#%% load libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import math
+from scipy.sparse.csgraph import shortest_path
+import matplotlib.colors as mcol
+from matplotlib import cm
+from ot.gromov import fgw_barycenters
+#%% Graph functions
+
+
+def find_thresh(C, inf=0.5, sup=3, step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist = []
+ search = np.linspace(inf, sup, step)
+ for thresh in search:
+ Cprime = sp_to_adjency(C, 0, thresh)
+ SC = shortest_path(Cprime, method='D')
+ SC[SC == float('inf')] = 100
+ dist.append(np.linalg.norm(SC - C))
+ return search[np.argmin(dist)], dist
+
+
+def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance from which the new value is set to 1
+ threshsup : float
+ The maximum value of distance from which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The threshold matrix. Each element is in {0,1}
+ """
+ H = np.zeros_like(C)
+ np.fill_diagonal(H, np.diagonal(C))
+ C = C - H
+ C = np.minimum(np.maximum(C, threshinf), threshsup)
+ C[C == threshsup] = 0
+ C[C != 0] = 1
+
+ return C
+
+
+def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
+ """ Create a noisy circular graph
+ """
+ g = nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
+ else:
+ g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
+ g.add_edge(i, i + 1)
+ if structure_noise:
+ randomint = np.random.randint(0, p)
+ if randomint == 0:
+ if i <= N - 3:
+ g.add_edge(i, i + 2)
+ if i == N - 2:
+ g.add_edge(i, 0)
+ if i == N - 1:
+ g.add_edge(i, 1)
+ g.add_edge(N, 0)
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
+ else:
+ g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
+ return g
+
+
+def graph_colors(nx_graph, vmin=0, vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
+ val_map[k] = cpick.to_rgba(v)
+ colors = []
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
+# We build a dataset of noisy circular graphs.
+# Noise is added on the structures by random connections and on the features by gaussian noise.
+
+
+np.random.seed(30)
+X0 = []
+for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
+
+plt.figure(figsize=(8, 10))
+for i in range(len(X0)):
+ plt.subplot(3, 3, i + 1)
+ g = X0[i]
+ pos = nx.kamada_kawai_layout(g)
+ nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
+plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
+plt.show()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+# Features distances are the euclidean distances
+Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
+Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
+lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
+sizebary = 15 # we choose a barycenter with 15 nodes
+
+A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
+
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
+bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
+
+#%%
+pos = nx.kamada_kawai_layout(bary)
+nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
+plt.suptitle('Barycenter', fontsize=20)
+plt.show()
diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/plot_barycenter_lp_vs_entropic.py
new file mode 100644
index 0000000..d7c72d0
--- /dev/null
+++ b/examples/plot_barycenter_lp_vs_entropic.py
@@ -0,0 +1,286 @@
+# -*- coding: utf-8 -*-
+"""
+=================================================================================
+1D Wasserstein barycenter comparison between exact LP and entropic regularization
+=================================================================================
+
+This example illustrates the computation of regularized Wasserstein Barycenter
+as proposed in [3] and exact LP barycenters using standard LP solver.
+
+It reproduces approximately Figure 3.1 and 3.2 from the following paper:
+Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
+Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection # noqa
+
+#import ot.lp.cvx as cvx
+
+##############################################################################
+# Gaussian Data
+# -------------
+
+#%% parameters
+
+problems = []
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+##############################################################################
+# Stair Data
+# ----------
+
+#%% parameters
+
+a1 = 1.0 * (x > 10) * (x < 50)
+a2 = 1.0 * (x > 60) * (x < 80)
+
+a1 /= a1.sum()
+a2 /= a2.sum()
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+
+##############################################################################
+# Dirac Data
+# ----------
+
+#%% parameters
+
+a1 = np.zeros(n)
+a2 = np.zeros(n)
+
+a1[10] = .25
+a1[20] = .5
+a1[30] = .25
+a2[80] = 1
+
+
+a1 /= a1.sum()
+a2 /= a2.sum()
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+
+##############################################################################
+# Final figure
+# ------------
+#
+
+#%% plot
+
+nbm = len(problems)
+nbm2 = (nbm // 2)
+
+
+pl.figure(2, (20, 6))
+pl.clf()
+
+for i in range(nbm):
+
+ A = problems[i][0]
+ bary_l2 = problems[i][1][0]
+ bary_wass = problems[i][1][1]
+ bary_wass2 = problems[i][1][2]
+
+ pl.subplot(2, nbm, 1 + i)
+ for j in range(n_distributions):
+ pl.plot(x, A[:, j])
+ if i == nbm2:
+ pl.title('Distributions')
+ pl.xticks(())
+ pl.yticks(())
+
+ pl.subplot(2, nbm, 1 + i + nbm)
+
+ pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ if i == nbm - 1:
+ pl.legend()
+ if i == nbm2:
+ pl.title('Barycenters')
+
+ pl.xticks(())
+ pl.yticks(())
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
new file mode 100644
index 0000000..7ed2b01
--- /dev/null
+++ b/examples/plot_compute_emd.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+"""
+=================
+Plot multiple EMD
+=================
+
+Shows how to compute multiple EMD and Sinkhorn with two differnt
+ground metrics and plot their values for diffeent distributions.
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from ot.datasets import make_1D_gauss as gauss
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+n_target = 50 # nb target distributions
+
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+lst_m = np.linspace(20, 90, n_target)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+
+B = np.zeros((n, n_target))
+
+for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
+
+# loss matrix and normalization
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+M /= M.max()
+M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+M2 /= M2.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'b', label='Source distribution')
+pl.title('Source distribution')
+pl.subplot(2, 1, 2)
+pl.plot(x, B, label='Target distributions')
+pl.title('Target distributions')
+pl.tight_layout()
+
+
+##############################################################################
+# Compute EMD for the different losses
+# ------------------------------------
+
+#%% Compute and plot distributions and loss matrix
+
+d_emd = ot.emd2(a, B, M) # direct computation of EMD
+d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
+
+
+pl.figure(2)
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
+pl.title('EMD distances')
+pl.legend()
+
+##############################################################################
+# Compute Sinkhorn for the different losses
+# -----------------------------------------
+
+#%%
+reg = 1e-2
+d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
+
+pl.figure(2)
+pl.clf()
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
+pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
+pl.title('EMD distances')
+pl.legend()
+
+pl.show()
diff --git a/examples/plot_convolutional_barycenter.py b/examples/plot_convolutional_barycenter.py
new file mode 100644
index 0000000..e74db04
--- /dev/null
+++ b/examples/plot_convolutional_barycenter.py
@@ -0,0 +1,92 @@
+
+#%%
+# -*- coding: utf-8 -*-
+"""
+============================================
+Convolutional Wasserstein Barycenter example
+============================================
+
+This example is designed to illustrate how the Convolutional Wasserstein Barycenter
+function of POT works.
+"""
+
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
+f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
+f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
+f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+
+A = []
+f1 = f1 / np.sum(f1)
+f2 = f2 / np.sum(f2)
+f3 = f3 / np.sum(f3)
+f4 = f4 / np.sum(f4)
+A.append(f1)
+A.append(f2)
+A.append(f3)
+A.append(f4)
+A = np.array(A)
+
+nb_images = 5
+
+# those are the four corners coordinates that will be interpolated by bilinear
+# interpolation
+v1 = np.array((1, 0, 0, 0))
+v2 = np.array((0, 1, 0, 0))
+v3 = np.array((0, 0, 1, 0))
+v4 = np.array((0, 0, 0, 1))
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+pl.figure(figsize=(10, 10))
+pl.title('Convolutional Wasserstein Barycenters in POT')
+cm = 'Blues'
+# regularization parameter
+reg = 0.004
+for i in range(nb_images):
+ for j in range(nb_images):
+ pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
+ tx = float(i) / (nb_images - 1)
+ ty = float(j) / (nb_images - 1)
+
+ # weights are constructed by bilinear interpolation
+ tmp1 = (1 - tx) * v1 + tx * v2
+ tmp2 = (1 - tx) * v3 + tx * v4
+ weights = (1 - ty) * tmp1 + ty * tmp2
+
+ if i == 0 and j == 0:
+ pl.imshow(f1, cmap=cm)
+ pl.axis('off')
+ elif i == 0 and j == (nb_images - 1):
+ pl.imshow(f3, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == 0:
+ pl.imshow(f2, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == (nb_images - 1):
+ pl.imshow(f4, cmap=cm)
+ pl.axis('off')
+ else:
+ # call to barycenter computation
+ pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
+ pl.axis('off')
+pl.show()
diff --git a/examples/plot_fgw.py b/examples/plot_fgw.py
new file mode 100644
index 0000000..43efc94
--- /dev/null
+++ b/examples/plot_fgw.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+Plot Fused-gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures[18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import matplotlib.pyplot as pl
+import numpy as np
+import ot
+from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
+##############################################################################
+# Generate data
+# ---------
+
+#%% parameters
+# We create two 1D random measures
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
+
+np.random.seed(0)
+
+phi = np.arange(n)[:, None]
+xs = phi + sig * np.random.randn(n, 1)
+ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)
+
+phi2 = np.arange(n2)[:, None]
+xt = phi2 + sig * np.random.randn(n2, 1)
+yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
+yt = yt[::-1, :]
+
+p = ot.unif(n)
+q = ot.unif(n2)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.close(10)
+pl.figure(10, (7, 7))
+
+pl.subplot(2, 1, 1)
+
+pl.scatter(ys, xs, c=phi, s=70)
+pl.ylabel('Feature value a', fontsize=20)
+pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+pl.xticks(())
+pl.yticks(())
+pl.subplot(2, 1, 2)
+pl.scatter(yt, xt, c=phi2, s=70)
+pl.xlabel('coordinates x/y', fontsize=25)
+pl.ylabel('Feature value b', fontsize=20)
+pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+pl.yticks(())
+pl.tight_layout()
+pl.show()
+
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# ---------
+
+#%% Structure matrices and across-features distance matrix
+C1 = ot.dist(xs)
+C2 = ot.dist(xt)
+M = ot.dist(ys, yt)
+w1 = ot.unif(C1.shape[0])
+w2 = ot.unif(C2.shape[0])
+Got = ot.emd([], [], M)
+
+##############################################################################
+# Plot matrices
+# ---------
+
+#%%
+cmap = 'Reds'
+pl.close(10)
+pl.figure(10, (5, 5))
+fs = 15
+l_x = [0, 5, 10, 15]
+l_y = [0, 5, 10, 15, 20, 25]
+gs = pl.GridSpec(5, 5)
+
+ax1 = pl.subplot(gs[3:, :2])
+
+pl.imshow(C1, cmap=cmap, interpolation='nearest')
+pl.title("$C_1$", fontsize=fs)
+pl.xlabel("$k$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(l_x)
+pl.yticks(l_x)
+
+ax2 = pl.subplot(gs[:3, 2:])
+
+pl.imshow(C2, cmap=cmap, interpolation='nearest')
+pl.title("$C_2$", fontsize=fs)
+pl.ylabel("$l$", fontsize=fs)
+#pl.ylabel("$l$",fontsize=fs)
+pl.xticks(())
+pl.yticks(l_y)
+ax2.set_aspect('auto')
+
+ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
+pl.imshow(M, cmap=cmap, interpolation='nearest')
+pl.yticks(l_x)
+pl.xticks(l_y)
+pl.ylabel("$i$", fontsize=fs)
+pl.title("$M_{AB}$", fontsize=fs)
+pl.xlabel("$j$", fontsize=fs)
+pl.tight_layout()
+ax3.set_aspect('auto')
+pl.show()
+
+##############################################################################
+# Compute FGW/GW
+# ---------
+
+#%% Computing FGW and GW
+alpha = 1e-3
+
+ot.tic()
+Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
+ot.toc()
+
+#%reload_ext WGW
+Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+
+##############################################################################
+# Visualize transport matrices
+# ---------
+
+#%% visu OT matrix
+cmap = 'Blues'
+fs = 15
+pl.figure(2, (13, 5))
+pl.clf()
+pl.subplot(1, 3, 1)
+pl.imshow(Got, cmap=cmap, interpolation='nearest')
+#pl.xlabel("$y$",fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(())
+
+pl.title('Wasserstein ($M$ only)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(Gg, cmap=cmap, interpolation='nearest')
+pl.title('Gromov ($C_1,C_2$ only)')
+pl.xticks(())
+pl.subplot(1, 3, 3)
+pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
+pl.title('FGW ($M+C_1,C_2$)')
+
+pl.xlabel("$j$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+
+pl.tight_layout()
+pl.show()
diff --git a/examples/plot_free_support_barycenter.py b/examples/plot_free_support_barycenter.py
new file mode 100644
index 0000000..64b89e4
--- /dev/null
+++ b/examples/plot_free_support_barycenter.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D free support Wasserstein barycenters of distributions
+====================================================
+
+Illustration of 2D Wasserstein barycenters if discributions that are weighted
+sum of diracs.
+
+"""
+
+# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+#%% parameters and data generation
+N = 3
+d = 2
+measures_locations = []
+measures_weights = []
+
+for i in range(N):
+
+ n_i = np.random.randint(low=1, high=20) # nb samples
+
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
+
+ measures_locations.append(x_i)
+ measures_weights.append(b_i)
+
+
+##############################################################################
+# Compute free support barycenter
+# -------------
+
+k = 10 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1)
+for (x_i, b_i) in zip(measures_locations, measures_weights):
+ color = np.random.randint(low=1, high=10 * N)
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc=0)
+pl.show()
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
new file mode 100644
index 0000000..deb2f86
--- /dev/null
+++ b/examples/plot_gromov.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D # noqa
+import ot
+
+#############################################################################
+#
+# Sample two Gaussian distributions (2D and 3D)
+# ---------------------------------------------
+#
+# The Gromov-Wasserstein distance allows to compute distances with samples that
+# do not belong to the same metric space. For demonstration purpose, we sample
+# two Gaussian distributions in 2- and 3-dimensional spaces.
+
+
+n_samples = 30 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+#############################################################################
+#
+# Plotting the distributions
+# --------------------------
+
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+#############################################################################
+#
+# Compute distance kernels, normalize them and then display
+# ---------------------------------------------------------
+
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+C1 /= C1.max()
+C2 /= C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+#############################################################################
+#
+# Compute Gromov-Wasserstein plans and distance
+# ---------------------------------------------
+
+p = ot.unif(n_samples)
+q = ot.unif(n_samples)
+
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
+
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
+
+
+print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(gw0, cmap='jet')
+pl.title('Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(gw, cmap='jet')
+pl.title('Entropic Gromov Wasserstein')
+
+pl.show()
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
new file mode 100755
index 0000000..58fc51a
--- /dev/null
+++ b/examples/plot_gromov_barycenter.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+
+import scipy.ndimage as spi
+import matplotlib.pylab as pl
+from sklearn import manifold
+from sklearn.decomposition import PCA
+
+import ot
+
+##############################################################################
+# Smacof MDS
+# ----------
+#
+# This function allows to find an embedding of points given a dissimilarity matrix
+# that will be given by the output of the algorithm
+
+
+def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+ Returns an interpolated point cloud following the dissimilarity matrix C
+ using SMACOF multidimensional scaling (MDS) in specific dimensionned
+ target space
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+ relative tolerance w.r.t stress to declare converge
+
+ Returns
+ -------
+ npos : ndarray, shape (R, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ 2,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity="precomputed",
+ random_state=rng,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+shapes = [square, cross, triangle, star]
+
+S = 4
+xs = [[] for i in range(S)]
+
+
+for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+
+ns = [len(xs[s]) for s in range(S)]
+n_samples = 30
+
+"""Compute all distances matrices for the four shapes"""
+Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+Cs = [cs / cs.max() for cs in Cs]
+
+ps = [ot.unif(ns[s]) for s in range(S)]
+p = ot.unif(n_samples)
+
+
+lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+Ct01 = [0 for i in range(2)]
+for i in range(2):
+ Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
+ [ps[0], ps[1]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct02 = [0 for i in range(2)]
+for i in range(2):
+ Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
+ [ps[0], ps[2]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct13 = [0 for i in range(2)]
+for i in range(2):
+ Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
+ [ps[1], ps[3]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct23 = [0 for i in range(2)]
+for i in range(2):
+ Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
+ [ps[2], ps[3]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+
+##############################################################################
+# Visualization
+# -------------
+#
+# The PCA helps in getting consistency between the rotations
+
+
+clf = PCA(n_components=2)
+npos = [0, 0, 0, 0]
+npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+npost01 = [0, 0]
+npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+npost02 = [0, 0]
+npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+npost13 = [0, 0]
+npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+npost23 = [0, 0]
+npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+fig = pl.figure(figsize=(10, 10))
+
+ax1 = pl.subplot2grid((4, 4), (0, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ax2 = pl.subplot2grid((4, 4), (0, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ax3 = pl.subplot2grid((4, 4), (0, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ax4 = pl.subplot2grid((4, 4), (0, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ax5 = pl.subplot2grid((4, 4), (1, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ax6 = pl.subplot2grid((4, 4), (1, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ax7 = pl.subplot2grid((4, 4), (2, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ax8 = pl.subplot2grid((4, 4), (2, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ax9 = pl.subplot2grid((4, 4), (3, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ax10 = pl.subplot2grid((4, 4), (3, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ax11 = pl.subplot2grid((4, 4), (3, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ax12 = pl.subplot2grid((4, 4), (3, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
new file mode 100644
index 0000000..2c58def
--- /dev/null
+++ b/examples/plot_optim_OTreg.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+"""
+==================================
+Regularized OT with generic solver
+==================================
+
+Illustrates the use of the generic solver for regularized OT with
+user-designed regularization term. It uses Conditional gradient as in [6] and
+generalized Conditional Gradient as proposed in [5][7].
+
+
+[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for
+Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
+Intelligence , vol.PP, no.99, pp.1-1.
+
+[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
+7(3), 1853-1882.
+
+[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
+conditional gradient: analysis of convergence and applications.
+arXiv preprint arXiv:1510.06567.
+
+
+
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+b = ot.datasets.make_1D_gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Solve EMD
+# ---------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve EMD with Frobenius norm regularization
+# --------------------------------------------
+
+#%% Example with Frobenius norm regularization
+
+
+def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+def df(G):
+ return G
+
+
+reg = 1e-1
+
+Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+pl.figure(3)
+ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
+
+##############################################################################
+# Solve EMD with entropic regularization
+# --------------------------------------
+
+#%% Example with entropic regularization
+
+
+def f(G):
+ return np.sum(G * np.log(G))
+
+
+def df(G):
+ return np.log(G) + 1.
+
+
+reg = 1e-3
+
+Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
+
+##############################################################################
+# Solve EMD with Frobenius norm + entropic regularization
+# -------------------------------------------------------
+
+#%% Example with Frobenius norm + entropic regularization with gcg
+
+
+def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+def df(G):
+ return G
+
+
+reg1 = 1e-3
+reg2 = 1e-1
+
+Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
+
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
+pl.show()
diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py
new file mode 100644
index 0000000..c311fbd
--- /dev/null
+++ b/examples/plot_otda_classes.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+"""
+========================
+OT for domain adaptation
+========================
+
+This example introduces a domain adaptation in a 2D setting and the 4 OTDA
+approaches currently supported in POT.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 150
+n_target_samples = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization
+ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization l1l2
+ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
+ verbose=True)
+ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
+
+
+##############################################################################
+# Fig 1 : plots source and target samples
+# ---------------------------------------
+
+pl.figure(1, figsize=(10, 5))
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plot optimal couplings and transported samples
+# ------------------------------------------------------
+
+param_img = {'interpolation': 'nearest'}
+
+pl.figure(2, figsize=(15, 8))
+pl.subplot(2, 4, 1)
+pl.imshow(ot_emd.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.subplot(2, 4, 2)
+pl.imshow(ot_sinkhorn.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 4, 3)
+pl.imshow(ot_lpl1.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+pl.subplot(2, 4, 4)
+pl.imshow(ot_l1l2.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornL1l2Transport')
+
+pl.subplot(2, 4, 5)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc="lower left")
+
+pl.subplot(2, 4, 6)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornTransport')
+
+pl.subplot(2, 4, 7)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornLpl1Transport')
+
+pl.subplot(2, 4, 8)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornL1l2Transport')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_otda_color_images.py b/examples/plot_otda_color_images.py
new file mode 100644
index 0000000..62383a2
--- /dev/null
+++ b/examples/plot_otda_color_images.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+"""
+=============================
+OT for image color adaptation
+=============================
+
+This example presents a way of transferring colors between two images
+with Optimal Transport as introduced in [6]
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport.
+SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+from scipy import ndimage
+import matplotlib.pylab as pl
+import ot
+
+
+r = np.random.RandomState(42)
+
+
+def im2mat(I):
+ """Converts an image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# Loading images
+I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+# training samples
+nb = 1000
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
+
+Xs = X1[idx1, :]
+Xt = X2[idx2, :]
+
+
+##############################################################################
+# Plot original image
+# -------------------
+
+pl.figure(1, figsize=(6.4, 3))
+
+pl.subplot(1, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+
+
+##############################################################################
+# Scatter plot of colors
+# ----------------------
+
+pl.figure(2, figsize=(6.4, 3))
+
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMDTransport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# SinkhornTransport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# prediction between images (using out of sample prediction as in [6])
+transp_Xs_emd = ot_emd.transform(Xs=X1)
+transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
+
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+transp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)
+
+I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
+I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
+
+I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
+
+
+##############################################################################
+# Plot new images
+# ---------------
+
+pl.figure(3, figsize=(8, 4))
+
+pl.subplot(2, 3, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(2, 3, 2)
+pl.imshow(I1t)
+pl.axis('off')
+pl.title('Image 1 Adapt')
+
+pl.subplot(2, 3, 3)
+pl.imshow(I1te)
+pl.axis('off')
+pl.title('Image 1 Adapt (reg)')
+
+pl.subplot(2, 3, 4)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+
+pl.subplot(2, 3, 5)
+pl.imshow(I2t)
+pl.axis('off')
+pl.title('Image 2 Adapt')
+
+pl.subplot(2, 3, 6)
+pl.imshow(I2te)
+pl.axis('off')
+pl.title('Image 2 Adapt (reg)')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_otda_d2.py b/examples/plot_otda_d2.py
new file mode 100644
index 0000000..cf22c2f
--- /dev/null
+++ b/examples/plot_otda_d2.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""
+===================================================
+OT for domain adaptation on empirical distributions
+===================================================
+
+This example introduces a domain adaptation in a 2D setting. It explicits
+the problem of domain adaptation and introduces some optimal transport
+approaches to solve it.
+
+Quantities such as optimal couplings, greater coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# generate data
+# -------------
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)
+
+# Cost matrix
+M = ot.dist(Xs, Xt, metric='sqeuclidean')
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization
+ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+# ---------------------------------------------------------------------
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(M, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Matrix of pairwise distances')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+# ---------------------------------------------------------
+pl.figure(2, figsize=(10, 6))
+
+pl.subplot(2, 3, 1)
+pl.imshow(ot_emd.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.subplot(2, 3, 2)
+pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(ot_lpl1.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+pl.subplot(2, 3, 4)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nEMDTransport')
+
+pl.subplot(2, 3, 5)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nSinkhornTransport')
+
+pl.subplot(2, 3, 6)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+# --------------------------------
+
+# display transported samples
+pl.figure(4, figsize=(10, 4))
+pl.subplot(1, 3, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 3, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornTransport')
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 3, 3)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornLpl1Transport')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py
new file mode 100644
index 0000000..c65bd4f
--- /dev/null
+++ b/examples/plot_otda_linear_mapping.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+============================
+Linear OT mapping estimation
+============================
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+n = 1000
+d = 2
+sigma = .1
+
+# source samples
+angles = np.random.rand(n, 1) * 2 * np.pi
+xs = np.concatenate((np.sin(angles), np.cos(angles)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xs[:n // 2, 1] += 2
+
+
+# target samples
+anglet = np.random.rand(n, 1) * 2 * np.pi
+xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xt[:n // 2, 1] += 2
+
+
+A = np.array([[1.5, .7], [.7, 1.5]])
+b = np.array([[4, 2]])
+xt = xt.dot(A) + b
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1, (5, 5))
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+
+
+##############################################################################
+# Estimate linear mapping and transport
+# -------------------------------------
+
+Ae, be = ot.da.OT_mapping_linear(xs, xt)
+
+xst = xs.dot(Ae) + be
+
+
+##############################################################################
+# Plot transported samples
+# ------------------------
+
+pl.figure(1, (5, 5))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+
+pl.show()
+
+##############################################################################
+# Load image data
+# ---------------
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+# Loading images
+I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+##############################################################################
+# Estimate mapping and adapt
+# ----------------------------
+
+mapping = ot.da.LinearTransport()
+
+mapping.fit(Xs=X1, Xt=X2)
+
+
+xst = mapping.transform(Xs=X1)
+xts = mapping.inverse_transform(Xt=X2)
+
+I1t = minmax(mat2im(xst, I1.shape))
+I2t = minmax(mat2im(xts, I2.shape))
+
+# %%
+
+
+##############################################################################
+# Plot transformed images
+# -----------------------
+
+pl.figure(2, figsize=(10, 7))
+
+pl.subplot(2, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Im. 1')
+
+pl.subplot(2, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Im. 2')
+
+pl.subplot(2, 2, 3)
+pl.imshow(I1t)
+pl.axis('off')
+pl.title('Mapping Im. 1')
+
+pl.subplot(2, 2, 4)
+pl.imshow(I2t)
+pl.axis('off')
+pl.title('Inverse mapping Im. 2')
diff --git a/examples/plot_otda_mapping.py b/examples/plot_otda_mapping.py
new file mode 100644
index 0000000..5880adf
--- /dev/null
+++ b/examples/plot_otda_mapping.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================
+OT mapping estimation for domain adaptation
+===========================================
+
+This example presents how to use MappingTransport to estimate at the same
+time both the coupling transport and approximate the transport map with either
+a linear or a kernelized mapping as introduced in [8].
+
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xs_new, _ = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.make_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target mode changes its variance (no linear mapping)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# MappingTransport with linear kernel
+ot_mapping_linear = ot.da.MappingTransport(
+ kernel="linear", mu=1e0, eta=1e-8, bias=True,
+ max_iter=20, verbose=True)
+
+ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+# for original source samples, transform applies barycentric mapping
+transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)
+
+# for out of source samples, transform applies the linear mapping
+transp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)
+
+
+# MappingTransport with gaussian kernel
+ot_mapping_gaussian = ot.da.MappingTransport(
+ kernel="gaussian", eta=1e-5, mu=1e-1, bias=True, sigma=1,
+ max_iter=10, verbose=True)
+ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+# for original source samples, transform applies barycentric mapping
+transp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)
+
+# for out of source samples, transform applies the gaussian mapping
+transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)
+
+
+##############################################################################
+# Plot transported samples
+# ------------------------
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',
+ label='Mapped source samples')
+pl.title("Bary. mapping (linear)")
+pl.legend(loc=0)
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],
+ c=ys, marker='+', label='Learned mapping')
+pl.title("Estim. mapping (linear)")
+
+pl.subplot(2, 2, 3)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,
+ marker='+', label='barycentric mapping')
+pl.title("Bary. mapping (kernel)")
+
+pl.subplot(2, 2, 4)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,
+ marker='+', label='Learned mapping')
+pl.title("Estim. mapping (kernel)")
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_otda_mapping_colors_images.py b/examples/plot_otda_mapping_colors_images.py
new file mode 100644
index 0000000..a20eca8
--- /dev/null
+++ b/examples/plot_otda_mapping_colors_images.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================================
+OT for image color adaptation with mapping estimation
+=====================================================
+
+OT for domain adaptation with image color adaptation [6] with mapping
+estimation [8].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),
+ 1853-1882.
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
+ discrete optimal transport", Neural Information Processing Systems (NIPS),
+ 2016.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+from scipy import ndimage
+import matplotlib.pylab as pl
+import ot
+
+r = np.random.RandomState(42)
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# Loading images
+I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+# training samples
+nb = 1000
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
+
+Xs = X1[idx1, :]
+Xt = X2[idx2, :]
+
+
+##############################################################################
+# Domain adaptation for pixel distribution transfer
+# -------------------------------------------------
+
+# EMDTransport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+transp_Xs_emd = ot_emd.transform(Xs=X1)
+Image_emd = minmax(mat2im(transp_Xs_emd, I1.shape))
+
+# SinkhornTransport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+
+ot_mapping_linear = ot.da.MappingTransport(
+ mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)
+ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+X1tl = ot_mapping_linear.transform(Xs=X1)
+Image_mapping_linear = minmax(mat2im(X1tl, I1.shape))
+
+ot_mapping_gaussian = ot.da.MappingTransport(
+ mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)
+ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping
+Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
+
+
+##############################################################################
+# Plot original images
+# --------------------
+
+pl.figure(1, figsize=(6.4, 3))
+pl.subplot(1, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Plot pixel values distribution
+# ------------------------------
+
+pl.figure(2, figsize=(6.4, 5))
+
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Plot transformed images
+# -----------------------
+
+pl.figure(2, figsize=(10, 5))
+
+pl.subplot(2, 3, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Im. 1')
+
+pl.subplot(2, 3, 4)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Im. 2')
+
+pl.subplot(2, 3, 2)
+pl.imshow(Image_emd)
+pl.axis('off')
+pl.title('EmdTransport')
+
+pl.subplot(2, 3, 5)
+pl.imshow(Image_sinkhorn)
+pl.axis('off')
+pl.title('SinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(Image_mapping_linear)
+pl.axis('off')
+pl.title('MappingTransport (linear)')
+
+pl.subplot(2, 3, 6)
+pl.imshow(Image_mapping_gaussian)
+pl.axis('off')
+pl.title('MappingTransport (gaussian)')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_otda_semi_supervised.py b/examples/plot_otda_semi_supervised.py
new file mode 100644
index 0000000..8a67720
--- /dev/null
+++ b/examples/plot_otda_semi_supervised.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+"""
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces a semi supervised domain adaptation in a 2D setting.
+It explicits the problem of semi supervised domain adaptation and introduces
+some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, greater coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)
+
+
+##############################################################################
+# Transport source samples onto target samples
+# --------------------------------------------
+
+
+# unsupervised domain adaptation
+ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+# semi-supervised domain adaptation
+ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+# semi supervised DA uses available labaled target samples to modify the cost
+# matrix involved in the OT problem. The cost of transporting a source sample
+# of class A onto a target sample of class B != A is set to infinite, or a
+# very large value
+
+# note that in the present case we consider that all the target samples are
+# labeled. For daily applications, some target sample might not have labels,
+# in this case the element of yt corresponding to these samples should be
+# filled with -1.
+
+# Warning: we recall that -1 cannot be used as a class label
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+# ---------------------------------------------------------------------
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - unsupervised DA')
+
+pl.subplot(2, 2, 4)
+pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - semisupervised DA')
+
+pl.tight_layout()
+
+# the optimal coupling in the semi-supervised DA case will exhibit " shape
+# similar" to the cost matrix, (block diagonal matrix)
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+# ---------------------------------------------------------
+
+pl.figure(2, figsize=(8, 4))
+
+pl.subplot(1, 2, 1)
+pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nUnsupervised DA')
+
+pl.subplot(1, 2, 2)
+pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSemi-supervised DA')
+
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+# --------------------------------
+
+# display transported samples
+pl.figure(4, figsize=(8, 4))
+pl.subplot(1, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornTransport')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
new file mode 100644
index 0000000..742f8d9
--- /dev/null
+++ b/examples/plot_stochastic.py
@@ -0,0 +1,208 @@
+"""
+==========================
+Stochastic examples
+==========================
+
+This example is designed to show how to use the stochatic optimization
+algorithms for descrete and semicontinous measures from the POT library.
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+import ot
+import ot.plot
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
+#############################################################################
+#############################################################################
+# DISCRETE CASE:
+#
+# Sample two discrete measures for the discrete case
+# ---------------------------------------------
+#
+# Define 2 discrete measures a and b, the points where are defined the source
+# and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SAG" method to find the transportation matrix in the discrete case
+# ---------------------------------------------
+#
+# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
+# results.
+
+method = "SAG"
+sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax)
+print(sag_pi)
+
+#############################################################################
+# SEMICONTINOUS CASE:
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "ASGD" method to find the transportation matrix in the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
+# results.
+
+method = "ASGD"
+asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, log=log)
+print(log_asgd['alpha'], log_asgd['beta'])
+print(asgd_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+
+##############################################################################
+# PLOT TRANSPORTATION MATRIX
+##############################################################################
+
+##############################################################################
+# Plot SAG results
+# ----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
+pl.show()
+
+
+##############################################################################
+# Plot ASGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
+#############################################################################
+#############################################################################
+# SEMICONTINOUS CASE:
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 100000
+lr = 0.1
+batch_size = 3
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SGD" dual method to find the transportation matrix in the
+# semicontinous case
+# ---------------------------------------------
+#
+# Call ot.solve_dual_entropic and plot the results.
+
+sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
+ batch_size, numItermax,
+ lr, log=log)
+print(log_sgd['alpha'], log_sgd['beta'])
+print(sgd_dual_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+##############################################################################
+# Plot SGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()