summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/README.txt8
-rw-r--r--examples/barycenters/README.txt4
-rw-r--r--examples/barycenters/plot_barycenter_1D.py (renamed from examples/plot_barycenter_1D.py)2
-rw-r--r--examples/barycenters/plot_barycenter_lp_vs_entropic.py (renamed from examples/plot_barycenter_lp_vs_entropic.py)2
-rw-r--r--examples/barycenters/plot_convolutional_barycenter.py (renamed from examples/plot_convolutional_barycenter.py)8
-rw-r--r--examples/barycenters/plot_free_support_barycenter.py (renamed from examples/plot_free_support_barycenter.py)6
-rw-r--r--examples/domain-adaptation/README.txt5
-rw-r--r--examples/domain-adaptation/plot_otda_classes.py (renamed from examples/plot_otda_classes.py)1
-rw-r--r--examples/domain-adaptation/plot_otda_color_images.py (renamed from examples/plot_otda_color_images.py)7
-rw-r--r--examples/domain-adaptation/plot_otda_d2.py (renamed from examples/plot_otda_d2.py)4
-rw-r--r--examples/domain-adaptation/plot_otda_jcpot.py171
-rw-r--r--examples/domain-adaptation/plot_otda_laplacian.py127
-rw-r--r--examples/domain-adaptation/plot_otda_linear_mapping.py (renamed from examples/plot_otda_linear_mapping.py)6
-rw-r--r--examples/domain-adaptation/plot_otda_mapping.py (renamed from examples/plot_otda_mapping.py)6
-rw-r--r--examples/domain-adaptation/plot_otda_mapping_colors_images.py (renamed from examples/plot_otda_mapping_colors_images.py)14
-rw-r--r--examples/domain-adaptation/plot_otda_semi_supervised.py (renamed from examples/plot_otda_semi_supervised.py)2
-rw-r--r--examples/gromov/README.txt4
-rw-r--r--examples/gromov/plot_barycenter_fgw.py (renamed from examples/plot_barycenter_fgw.py)11
-rw-r--r--examples/gromov/plot_fgw.py (renamed from examples/plot_fgw.py)16
-rw-r--r--examples/gromov/plot_gromov.py (renamed from examples/plot_gromov.py)0
-rwxr-xr-xexamples/gromov/plot_gromov_barycenter.py (renamed from examples/plot_gromov_barycenter.py)9
-rw-r--r--examples/others/README.txt5
-rw-r--r--examples/others/plot_WDA.py (renamed from examples/plot_WDA.py)10
-rw-r--r--examples/plot_OT_1D.py1
-rw-r--r--examples/plot_OT_1D_smooth.py2
-rw-r--r--examples/plot_OT_2D_samples.py2
-rw-r--r--examples/plot_OT_L1_vs_L2.py2
-rw-r--r--examples/plot_compute_emd.py6
-rw-r--r--examples/plot_optim_OTreg.py7
-rw-r--r--examples/plot_screenkhorn_1D.py71
-rw-r--r--examples/plot_stochastic.py101
-rw-r--r--examples/unbalanced-partial/README.txt3
-rw-r--r--examples/unbalanced-partial/plot_UOT_1D.py (renamed from examples/plot_UOT_1D.py)0
-rw-r--r--examples/unbalanced-partial/plot_UOT_barycenter_1D.py (renamed from examples/plot_UOT_barycenter_1D.py)6
-rwxr-xr-xexamples/unbalanced-partial/plot_partial_wass_and_gromov.py165
35 files changed, 684 insertions, 110 deletions
diff --git a/examples/README.txt b/examples/README.txt
index b08d3f1..69a9f84 100644
--- a/examples/README.txt
+++ b/examples/README.txt
@@ -1,4 +1,8 @@
-POT Examples
-============
+Examples gallery
+================
This is a gallery of all the POT example files.
+
+
+OT and regularized OT
+--------------------- \ No newline at end of file
diff --git a/examples/barycenters/README.txt b/examples/barycenters/README.txt
new file mode 100644
index 0000000..8461f7f
--- /dev/null
+++ b/examples/barycenters/README.txt
@@ -0,0 +1,4 @@
+
+
+Wasserstein barycenters
+----------------------- \ No newline at end of file
diff --git a/examples/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 6864301..63dc460 100644
--- a/examples/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -18,6 +18,8 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
index d7c72d0..57a6bac 100644
--- a/examples/plot_barycenter_lp_vs_entropic.py
+++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py
@@ -21,6 +21,8 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py
index e74db04..cbcd4a1 100644
--- a/examples/plot_convolutional_barycenter.py
+++ b/examples/barycenters/plot_convolutional_barycenter.py
@@ -26,10 +26,10 @@ import ot
# The four distributions are constructed from 4 simple images
-f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
-f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
-f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
-f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2]
+f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2]
+f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2]
+f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2]
A = []
f1 = f1 / np.sum(f1)
diff --git a/examples/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py
index 64b89e4..27ddc8e 100644
--- a/examples/plot_free_support_barycenter.py
+++ b/examples/barycenters/plot_free_support_barycenter.py
@@ -4,7 +4,7 @@
2D free support Wasserstein barycenters of distributions
====================================================
-Illustration of 2D Wasserstein barycenters if discributions that are weighted
+Illustration of 2D Wasserstein barycenters if distributions are weighted
sum of diracs.
"""
@@ -21,7 +21,7 @@ import ot
##############################################################################
# Generate data
# -------------
-#%% parameters and data generation
+
N = 3
d = 2
measures_locations = []
@@ -46,7 +46,7 @@ for i in range(N):
##############################################################################
# Compute free support barycenter
-# -------------
+# -------------------------------
k = 10 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
diff --git a/examples/domain-adaptation/README.txt b/examples/domain-adaptation/README.txt
new file mode 100644
index 0000000..81dd8d2
--- /dev/null
+++ b/examples/domain-adaptation/README.txt
@@ -0,0 +1,5 @@
+
+
+
+Domain adaptation examples
+-------------------------- \ No newline at end of file
diff --git a/examples/plot_otda_classes.py b/examples/domain-adaptation/plot_otda_classes.py
index c311fbd..f028022 100644
--- a/examples/plot_otda_classes.py
+++ b/examples/domain-adaptation/plot_otda_classes.py
@@ -17,7 +17,6 @@ approaches currently supported in POT.
import matplotlib.pylab as pl
import ot
-
##############################################################################
# Generate data
# -------------
diff --git a/examples/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py
index 62383a2..929365e 100644
--- a/examples/plot_otda_color_images.py
+++ b/examples/domain-adaptation/plot_otda_color_images.py
@@ -17,8 +17,9 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
-from scipy import ndimage
import matplotlib.pylab as pl
import ot
@@ -45,8 +46,8 @@ def minmax(I):
# -------------
# Loading images
-I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
X1 = im2mat(I1)
X2 = im2mat(I2)
diff --git a/examples/plot_otda_d2.py b/examples/domain-adaptation/plot_otda_d2.py
index cf22c2f..d8b2a93 100644
--- a/examples/plot_otda_d2.py
+++ b/examples/domain-adaptation/plot_otda_d2.py
@@ -18,12 +18,14 @@ of what the transport methods are doing.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import matplotlib.pylab as pl
import ot
import ot.plot
##############################################################################
-# generate data
+# Generate data
# -------------
n_samples_source = 150
diff --git a/examples/domain-adaptation/plot_otda_jcpot.py b/examples/domain-adaptation/plot_otda_jcpot.py
new file mode 100644
index 0000000..c495690
--- /dev/null
+++ b/examples/domain-adaptation/plot_otda_jcpot.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+"""
+========================
+OT for multi-source target shift
+========================
+
+This example introduces a target shift problem with two 2D source and 1 target domain.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+#
+# License: MIT License
+
+import pylab as pl
+import numpy as np
+import ot
+from ot.datasets import make_data_classif
+
+##############################################################################
+# Generate data
+# -------------
+n = 50
+sigma = 0.3
+np.random.seed(1985)
+
+p1 = .2
+dec1 = [0, 2]
+
+p2 = .9
+dec2 = [0, -2]
+
+pt = .4
+dect = [4, 0]
+
+xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)
+xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)
+xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)
+
+all_Xr = [xs1, xs2]
+all_Yr = [ys1, ys2]
+# %%
+
+da = 1.5
+
+
+def plot_ax(dec, name):
+ pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)
+ pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)
+ pl.text(dec[0] - .5, dec[1] + 2, name)
+
+
+##############################################################################
+# Fig 1 : plots source and target samples
+# ---------------------------------------
+
+pl.figure(1)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,
+ label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,
+ label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,
+ label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))
+pl.title('Data')
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+
+##############################################################################
+# Instantiate Sinkhorn transport algorithm and fit them for all source domains
+# ----------------------------------------------------------------------------
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')
+
+
+def print_G(G, xs, ys, xt):
+ for i in range(G.shape[0]):
+ for j in range(G.shape[1]):
+ if G[i, j] > 5e-4:
+ if ys[i]:
+ c = 'b'
+ else:
+ c = 'r'
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2)
+
+
+##############################################################################
+# Fig 2 : plot optimal couplings and transported samples
+# ------------------------------------------------------
+pl.figure(2)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)
+print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
+
+pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
+pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
+
+pl.title('Independent OT')
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+
+##############################################################################
+# Instantiate JCPOT adaptation algorithm and fit it
+# ----------------------------------------------------------------------------
+otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
+otda.fit(all_Xr, all_Yr, xt)
+
+ws1 = otda.proportions_.dot(otda.log_['D2'][0])
+ws2 = otda.proportions_.dot(otda.log_['D2'][1])
+
+pl.figure(3)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
+
+pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
+pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
+
+pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+
+##############################################################################
+# Run oracle transport algorithm with known proportions
+# ----------------------------------------------------------------------------
+h_res = np.array([1 - pt, pt])
+
+ws1 = h_res.dot(otda.log_['D2'][0])
+ws2 = h_res.dot(otda.log_['D2'][1])
+
+pl.figure(4)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
+
+pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
+pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
+
+pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+pl.show()
diff --git a/examples/domain-adaptation/plot_otda_laplacian.py b/examples/domain-adaptation/plot_otda_laplacian.py
new file mode 100644
index 0000000..67c8f67
--- /dev/null
+++ b/examples/domain-adaptation/plot_otda_laplacian.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+======================================================
+OT with Laplacian regularization for domain adaptation
+======================================================
+
+This example introduces a domain adaptation in a 2D setting and OTDA
+approach with Laplacian regularization.
+
+"""
+
+# Authors: Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 150
+n_target_samples = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# EMD Transport with Laplacian regularization
+ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1)
+ot_emd_laplace.fit(Xs=Xs, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs)
+
+##############################################################################
+# Fig 1 : plots source and target samples
+# ---------------------------------------
+
+pl.figure(1, figsize=(10, 5))
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plot optimal couplings and transported samples
+# ------------------------------------------------------
+
+param_img = {'interpolation': 'nearest'}
+
+pl.figure(2, figsize=(15, 8))
+pl.subplot(2, 3, 1)
+pl.imshow(ot_emd.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.figure(2, figsize=(15, 8))
+pl.subplot(2, 3, 2)
+pl.imshow(ot_sinkhorn.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(ot_emd_laplace.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDLaplaceTransport')
+
+pl.subplot(2, 3, 4)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc="lower left")
+
+pl.subplot(2, 3, 5)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornTransport')
+
+pl.subplot(2, 3, 6)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nEMDLaplaceTransport')
+pl.tight_layout()
+
+pl.show()
diff --git a/examples/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py
index c65bd4f..dbf16b8 100644
--- a/examples/plot_otda_linear_mapping.py
+++ b/examples/domain-adaptation/plot_otda_linear_mapping.py
@@ -12,6 +12,8 @@ Linear OT mapping estimation
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import pylab as pl
import ot
@@ -92,8 +94,8 @@ def minmax(I):
# Loading images
-I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
X1 = im2mat(I1)
diff --git a/examples/plot_otda_mapping.py b/examples/domain-adaptation/plot_otda_mapping.py
index 5880adf..d21d3c9 100644
--- a/examples/plot_otda_mapping.py
+++ b/examples/domain-adaptation/plot_otda_mapping.py
@@ -9,8 +9,8 @@ time both the coupling transport and approximate the transport map with either
a linear or a kernelized mapping as introduced in [8].
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
- "Mapping estimation for discrete optimal transport",
- Neural Information Processing Systems (NIPS), 2016.
+"Mapping estimation for discrete optimal transport",
+Neural Information Processing Systems (NIPS), 2016.
"""
# Authors: Remi Flamary <remi.flamary@unice.fr>
@@ -18,6 +18,8 @@ a linear or a kernelized mapping as introduced in [8].
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
index a20eca8..ee5c8b0 100644
--- a/examples/plot_otda_mapping_colors_images.py
+++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py
@@ -8,11 +8,10 @@ OT for domain adaptation with image color adaptation [6] with mapping
estimation [8].
[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
- discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),
- 1853-1882.
+discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
- discrete optimal transport", Neural Information Processing Systems (NIPS),
- 2016.
+discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
"""
@@ -21,8 +20,9 @@ estimation [8].
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
import numpy as np
-from scipy import ndimage
import matplotlib.pylab as pl
import ot
@@ -48,8 +48,8 @@ def minmax(I):
# -------------
# Loading images
-I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
-I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256
X1 = im2mat(I1)
diff --git a/examples/plot_otda_semi_supervised.py b/examples/domain-adaptation/plot_otda_semi_supervised.py
index 8a67720..478c3b8 100644
--- a/examples/plot_otda_semi_supervised.py
+++ b/examples/domain-adaptation/plot_otda_semi_supervised.py
@@ -18,6 +18,8 @@ of what the transport methods are doing.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
import matplotlib.pylab as pl
import ot
diff --git a/examples/gromov/README.txt b/examples/gromov/README.txt
new file mode 100644
index 0000000..9cc9c64
--- /dev/null
+++ b/examples/gromov/README.txt
@@ -0,0 +1,4 @@
+
+
+Gromov and Fused-Gromov-Wasserstein
+----------------------------------- \ No newline at end of file
diff --git a/examples/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py
index 77b0370..3f81765 100644
--- a/examples/plot_barycenter_fgw.py
+++ b/examples/gromov/plot_barycenter_fgw.py
@@ -4,14 +4,15 @@
Plot graphs' barycenter using FGW
=================================
-This example illustrates the computation barycenter of labeled graphs using FGW
+This example illustrates the computation barycenter of labeled graphs using
+FGW [18].
Requires networkx >=2
-.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
+[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
"""
diff --git a/examples/plot_fgw.py b/examples/gromov/plot_fgw.py
index 43efc94..97fe619 100644
--- a/examples/plot_fgw.py
+++ b/examples/gromov/plot_fgw.py
@@ -4,12 +4,12 @@
Plot Fused-gromov-Wasserstein
==============================
-This example illustrates the computation of FGW for 1D measures[18].
+This example illustrates the computation of FGW for 1D measures [18].
-.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
- and Courty Nicolas
- "Optimal Transport for structured data with application on graphs"
- International Conference on Machine Learning (ICML). 2019.
+[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
+and Courty Nicolas
+"Optimal Transport for structured data with application on graphs"
+International Conference on Machine Learning (ICML). 2019.
"""
@@ -17,6 +17,8 @@ This example illustrates the computation of FGW for 1D measures[18].
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
import matplotlib.pyplot as pl
import numpy as np
import ot
@@ -60,14 +62,14 @@ pl.subplot(2, 1, 1)
pl.scatter(ys, xs, c=phi, s=70)
pl.ylabel('Feature value a', fontsize=20)
-pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, y=1)
pl.xticks(())
pl.yticks(())
pl.subplot(2, 1, 2)
pl.scatter(yt, xt, c=phi2, s=70)
pl.xlabel('coordinates x/y', fontsize=25)
pl.ylabel('Feature value b', fontsize=20)
-pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, y=1)
pl.yticks(())
pl.tight_layout()
pl.show()
diff --git a/examples/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f86..deb2f86 100644
--- a/examples/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
diff --git a/examples/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py
index 58fc51a..f6f031a 100755
--- a/examples/plot_gromov_barycenter.py
+++ b/examples/gromov/plot_gromov_barycenter.py
@@ -17,7 +17,6 @@ computation in POT.
import numpy as np
import scipy as sp
-import scipy.ndimage as spi
import matplotlib.pylab as pl
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -90,10 +89,10 @@ def im2mat(I):
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
-square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
-cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
-triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
-star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
+cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
+triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
+star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
diff --git a/examples/others/README.txt b/examples/others/README.txt
new file mode 100644
index 0000000..df4c697
--- /dev/null
+++ b/examples/others/README.txt
@@ -0,0 +1,5 @@
+
+
+
+Other OT problems
+----------------- \ No newline at end of file
diff --git a/examples/plot_WDA.py b/examples/others/plot_WDA.py
index 93cc237..bdfa57d 100644
--- a/examples/plot_WDA.py
+++ b/examples/others/plot_WDA.py
@@ -16,6 +16,8 @@ Wasserstein Discriminant Analysis.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
@@ -31,6 +33,8 @@ from ot.dr import wda, fda
n = 1000 # nb samples in source and target datasets
nz = 0.2
+np.random.seed(1)
+
# generate circle dataset
t = np.random.rand(n) * 2 * np.pi
ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
@@ -86,7 +90,11 @@ reg = 1e0
k = 10
maxiter = 100
-Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+P0 = np.random.randn(xs.shape[1], p)
+
+P0 /= np.sqrt(np.sum(P0**2, 0, keepdims=True))
+
+Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter, P0=P0)
##############################################################################
diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py
index f33e2a4..15ead96 100644
--- a/examples/plot_OT_1D.py
+++ b/examples/plot_OT_1D.py
@@ -12,6 +12,7 @@ and their visualization.
# Author: Remi Flamary <remi.flamary@unice.fr>
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
import numpy as np
import matplotlib.pylab as pl
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py
index b690751..75cd295 100644
--- a/examples/plot_OT_1D_smooth.py
+++ b/examples/plot_OT_1D_smooth.py
@@ -13,6 +13,8 @@ and their visualization.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 6
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py
index 63126ba..1544e82 100644
--- a/examples/plot_OT_2D_samples.py
+++ b/examples/plot_OT_2D_samples.py
@@ -14,6 +14,8 @@ sum of diracs. The OT matrix is plotted with the samples.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 4
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py
index 37b429f..60353ab 100644
--- a/examples/plot_OT_L1_vs_L2.py
+++ b/examples/plot_OT_L1_vs_L2.py
@@ -16,6 +16,8 @@ https://arxiv.org/pdf/1706.07650.pdf
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py
index 7ed2b01..527a847 100644
--- a/examples/plot_compute_emd.py
+++ b/examples/plot_compute_emd.py
@@ -4,8 +4,8 @@
Plot multiple EMD
=================
-Shows how to compute multiple EMD and Sinkhorn with two differnt
-ground metrics and plot their values for diffeent distributions.
+Shows how to compute multiple EMD and Sinkhorn with two different
+ground metrics and plot their values for different distributions.
"""
@@ -14,6 +14,8 @@ ground metrics and plot their values for diffeent distributions.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 3
+
import numpy as np
import matplotlib.pylab as pl
import ot
diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py
index 2c58def..5eb15bd 100644
--- a/examples/plot_optim_OTreg.py
+++ b/examples/plot_optim_OTreg.py
@@ -6,7 +6,7 @@ Regularized OT with generic solver
Illustrates the use of the generic solver for regularized OT with
user-designed regularization term. It uses Conditional gradient as in [6] and
-generalized Conditional Gradient as proposed in [5][7].
+generalized Conditional Gradient as proposed in [5,7].
[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for
@@ -14,8 +14,8 @@ Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
Intelligence , vol.PP, no.99, pp.1-1.
[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
-Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
-7(3), 1853-1882.
+Regularized discrete optimal transport. SIAM Journal on Imaging
+Sciences, 7(3), 1853-1882.
[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
conditional gradient: analysis of convergence and applications.
@@ -24,6 +24,7 @@ arXiv preprint arXiv:1510.06567.
"""
+# sphinx_gallery_thumbnail_number = 4
import numpy as np
import matplotlib.pylab as pl
diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py
new file mode 100644
index 0000000..785642a
--- /dev/null
+++ b/examples/plot_screenkhorn_1D.py
@@ -0,0 +1,71 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Screened optimal transport
+===============================
+
+This example illustrates the computation of Screenkhorn [26].
+
+[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
+Screening Sinkhorn Algorithm for Regularized Optimal Transport,
+Advances in Neural Information Processing Systems 33 (NeurIPS).
+"""
+
+# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+from ot.bregman import screenkhorn
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve Screenkhorn
+# -----------------------
+
+# Screenkhorn
+lambd = 2e-03 # entropy parameter
+ns_budget = 30 # budget number of points to be keeped in the source distribution
+nt_budget = 30 # budget number of points to be keeped in the target distribution
+
+G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True)
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn')
+pl.show()
diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py
index 742f8d9..3a1ef31 100644
--- a/examples/plot_stochastic.py
+++ b/examples/plot_stochastic.py
@@ -1,10 +1,18 @@
"""
-==========================
+===================
Stochastic examples
-==========================
+===================
This example is designed to show how to use the stochatic optimization
-algorithms for descrete and semicontinous measures from the POT library.
+algorithms for discrete and semi-continuous measures from the POT library.
+
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F.
+Stochastic Optimization for Large-scale Optimal Transport.
+Advances in Neural Information Processing Systems (2016).
+
+[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. &
+Blondel, M. Large-scale Optimal Transport and Mapping Estimation.
+International Conference on Learning Representation (2018)
"""
@@ -19,16 +27,14 @@ import ot.plot
#############################################################################
-# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
-#############################################################################
-#############################################################################
-# DISCRETE CASE:
+# Compute the Transportation Matrix for the Semi-Dual Problem
+# -----------------------------------------------------------
#
-# Sample two discrete measures for the discrete case
-# ---------------------------------------------
+# Discrete case
+# `````````````
#
-# Define 2 discrete measures a and b, the points where are defined the source
-# and the target measures and finally the cost matrix c.
+# Sample two discrete measures for the discrete case and compute their cost
+# matrix c.
n_source = 7
n_target = 4
@@ -44,12 +50,7 @@ Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
#############################################################################
-#
# Call the "SAG" method to find the transportation matrix in the discrete case
-# ---------------------------------------------
-#
-# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
-# results.
method = "SAG"
sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
@@ -57,14 +58,12 @@ sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
print(sag_pi)
#############################################################################
-# SEMICONTINOUS CASE:
+# Semi-Continuous Case
+# ````````````````````
#
# Sample one general measure a, one discrete measures b for the semicontinous
-# case
-# ---------------------------------------------
-#
-# Define one general measure a, one discrete measures b, the points where
-# are defined the source and the target measures and finally the cost matrix c.
+# case, the points where source and target measures are defined and compute the
+# cost matrix.
n_source = 7
n_target = 4
@@ -81,13 +80,8 @@ Y_target = rng.randn(n_target, 2)
M = ot.dist(X_source, Y_target)
#############################################################################
-#
# Call the "ASGD" method to find the transportation matrix in the semicontinous
-# case
-# ---------------------------------------------
-#
-# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
-# results.
+# case.
method = "ASGD"
asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
@@ -96,23 +90,17 @@ print(log_asgd['alpha'], log_asgd['beta'])
print(asgd_pi)
#############################################################################
-#
# Compare the results with the Sinkhorn algorithm
-# ---------------------------------------------
-#
-# Call the Sinkhorn algorithm from POT
sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
##############################################################################
-# PLOT TRANSPORTATION MATRIX
-##############################################################################
-
-##############################################################################
-# Plot SAG results
-# ----------------
+# Plot Transportation Matrices
+# ````````````````````````````
+#
+# For SAG
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
@@ -120,8 +108,7 @@ pl.show()
##############################################################################
-# Plot ASGD results
-# -----------------
+# For ASGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
@@ -129,8 +116,7 @@ pl.show()
##############################################################################
-# Plot Sinkhorn results
-# ---------------------
+# For Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
@@ -138,17 +124,14 @@ pl.show()
#############################################################################
-# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
-#############################################################################
-#############################################################################
-# SEMICONTINOUS CASE:
+# Compute the Transportation Matrix for the Dual Problem
+# ------------------------------------------------------
#
-# Sample one general measure a, one discrete measures b for the semicontinous
-# case
-# ---------------------------------------------
+# Semi-continuous case
+# ````````````````````
#
-# Define one general measure a, one discrete measures b, the points where
-# are defined the source and the target measures and finally the cost matrix c.
+# Sample one general measure a, one discrete measures b for the semi-continuous
+# case and compute the cost matrix c.
n_source = 7
n_target = 4
@@ -169,10 +152,7 @@ M = ot.dist(X_source, Y_target)
#############################################################################
#
# Call the "SGD" dual method to find the transportation matrix in the
-# semicontinous case
-# ---------------------------------------------
-#
-# Call ot.solve_dual_entropic and plot the results.
+# semi-continuous case
sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
batch_size, numItermax,
@@ -183,7 +163,7 @@ print(sgd_dual_pi)
#############################################################################
#
# Compare the results with the Sinkhorn algorithm
-# ---------------------------------------------
+# ```````````````````````````````````````````````
#
# Call the Sinkhorn algorithm from POT
@@ -191,8 +171,10 @@ sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
print(sinkhorn_pi)
##############################################################################
-# Plot SGD results
-# -----------------
+# Plot Transportation Matrices
+# ````````````````````````````
+#
+# For SGD
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
@@ -200,8 +182,7 @@ pl.show()
##############################################################################
-# Plot Sinkhorn results
-# ---------------------
+# For Sinkhorn
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
diff --git a/examples/unbalanced-partial/README.txt b/examples/unbalanced-partial/README.txt
new file mode 100644
index 0000000..2f404f0
--- /dev/null
+++ b/examples/unbalanced-partial/README.txt
@@ -0,0 +1,3 @@
+
+Unbalanced and Partial OT
+------------------------- \ No newline at end of file
diff --git a/examples/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index 2ea8b05..2ea8b05 100644
--- a/examples/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
index c8d9d3b..931798b 100644
--- a/examples/plot_UOT_barycenter_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
@@ -16,6 +16,8 @@ as proposed in [10] for Unbalanced inputs.
#
# License: MIT License
+# sphinx_gallery_thumbnail_number = 2
+
import numpy as np
import matplotlib.pylab as pl
import ot
@@ -77,7 +79,7 @@ bary_l2 = A.dot(weights)
reg = 1e-3
alpha = 1.
-bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights)
pl.figure(2)
pl.clf()
@@ -111,7 +113,7 @@ for i in range(0, n_weight):
weight = weight_list[i]
weights = np.array([1 - weight, weight])
B_l2[:, i] = A.dot(weights)
- B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights)
# plot interpolation
diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
new file mode 100755
index 0000000..0c5cbf9
--- /dev/null
+++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+"""
+==================================================
+Partial Wasserstein and Gromov-Wasserstein example
+==================================================
+
+This example is designed to show how to use the Partial (Gromov-)Wassertsein
+distance computation in POT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+#############################################################################
+#
+# Sample two 2D Gaussian distributions and plot them
+# --------------------------------------------------
+#
+# For demonstration purpose, we sample two Gaussian distributions in 2-d
+# spaces and add some random noise.
+
+
+n_samples = 20 # nb samples (gaussian)
+n_noise = 20 # nb of samples (noise)
+
+mu = np.array([0, 0])
+cov = np.array([[1, 0], [0, 2]])
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+M = sp.spatial.distance.cdist(xs, xt)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(131)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(132)
+ax2.scatter(xt[:, 0], xt[:, 1], color='r')
+ax3 = fig.add_subplot(133)
+ax3.imshow(M)
+pl.show()
+
+#############################################################################
+#
+# Compute partial Wasserstein plans and distance
+# ----------------------------------------------
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
+w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
+ log=True)
+
+print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
+print('Entropic partial Wasserstein distance (m = 0.5): ' +
+ str(log['partial_w_dist']))
+
+pl.figure(1, (10, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(w0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(w, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()
+
+
+#############################################################################
+#
+# Sample one 2D and 3D Gaussian distributions and plot them
+# ---------------------------------------------------------
+#
+# The Gromov-Wasserstein distance allows to compute distances with samples that
+# do not belong to the same metric space. For demonstration purpose, we sample
+# two Gaussian distributions in 2- and 3-dimensional spaces.
+
+n_samples = 20 # nb samples
+n_noise = 10 # nb of samples (noise)
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([0, 0, 0])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+#############################################################################
+#
+# Compute partial Gromov-Wasserstein plans and distance
+# -----------------------------------------------------
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+# transport 100% of the mass
+print('-----m = 1')
+m = 1
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
+print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 1")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic Wasserstein')
+pl.show()
+
+# transport 2/3 of the mass
+print('-----m = 2/3')
+m = 2 / 3
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Partial Wasserstein distance (m = 2/3): ' +
+ str(log0['partial_gw_dist']))
+print('Entropic partial Wasserstein distance (m = 2/3): ' +
+ str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 2/3")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()