diff options
Diffstat (limited to 'examples')
35 files changed, 684 insertions, 110 deletions
diff --git a/examples/README.txt b/examples/README.txt index b08d3f1..69a9f84 100644 --- a/examples/README.txt +++ b/examples/README.txt @@ -1,4 +1,8 @@ -POT Examples -============ +Examples gallery +================ This is a gallery of all the POT example files. + + +OT and regularized OT +---------------------
\ No newline at end of file diff --git a/examples/barycenters/README.txt b/examples/barycenters/README.txt new file mode 100644 index 0000000..8461f7f --- /dev/null +++ b/examples/barycenters/README.txt @@ -0,0 +1,4 @@ + + +Wasserstein barycenters +-----------------------
\ No newline at end of file diff --git a/examples/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py index 6864301..63dc460 100644 --- a/examples/plot_barycenter_1D.py +++ b/examples/barycenters/plot_barycenter_1D.py @@ -18,6 +18,8 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_barycenter_lp_vs_entropic.py b/examples/barycenters/plot_barycenter_lp_vs_entropic.py index d7c72d0..57a6bac 100644 --- a/examples/plot_barycenter_lp_vs_entropic.py +++ b/examples/barycenters/plot_barycenter_lp_vs_entropic.py @@ -21,6 +21,8 @@ SIAM Journal on Scientific Computing, 37(2), A1111-A1138. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_convolutional_barycenter.py b/examples/barycenters/plot_convolutional_barycenter.py index e74db04..cbcd4a1 100644 --- a/examples/plot_convolutional_barycenter.py +++ b/examples/barycenters/plot_convolutional_barycenter.py @@ -26,10 +26,10 @@ import ot # The four distributions are constructed from 4 simple images -f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2] -f2 = 1 - pl.imread('../data/duck.png')[:, :, 2] -f3 = 1 - pl.imread('../data/heart.png')[:, :, 2] -f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2] +f1 = 1 - pl.imread('../../data/redcross.png')[:, :, 2] +f2 = 1 - pl.imread('../../data/duck.png')[:, :, 2] +f3 = 1 - pl.imread('../../data/heart.png')[:, :, 2] +f4 = 1 - pl.imread('../../data/tooth.png')[:, :, 2] A = [] f1 = f1 / np.sum(f1) diff --git a/examples/plot_free_support_barycenter.py b/examples/barycenters/plot_free_support_barycenter.py index 64b89e4..27ddc8e 100644 --- a/examples/plot_free_support_barycenter.py +++ b/examples/barycenters/plot_free_support_barycenter.py @@ -4,7 +4,7 @@ 2D free support Wasserstein barycenters of distributions ==================================================== -Illustration of 2D Wasserstein barycenters if discributions that are weighted +Illustration of 2D Wasserstein barycenters if distributions are weighted sum of diracs. """ @@ -21,7 +21,7 @@ import ot ############################################################################## # Generate data # ------------- -#%% parameters and data generation + N = 3 d = 2 measures_locations = [] @@ -46,7 +46,7 @@ for i in range(N): ############################################################################## # Compute free support barycenter -# ------------- +# ------------------------------- k = 10 # number of Diracs of the barycenter X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations diff --git a/examples/domain-adaptation/README.txt b/examples/domain-adaptation/README.txt new file mode 100644 index 0000000..81dd8d2 --- /dev/null +++ b/examples/domain-adaptation/README.txt @@ -0,0 +1,5 @@ + + + +Domain adaptation examples +--------------------------
\ No newline at end of file diff --git a/examples/plot_otda_classes.py b/examples/domain-adaptation/plot_otda_classes.py index c311fbd..f028022 100644 --- a/examples/plot_otda_classes.py +++ b/examples/domain-adaptation/plot_otda_classes.py @@ -17,7 +17,6 @@ approaches currently supported in POT. import matplotlib.pylab as pl import ot - ############################################################################## # Generate data # ------------- diff --git a/examples/plot_otda_color_images.py b/examples/domain-adaptation/plot_otda_color_images.py index 62383a2..929365e 100644 --- a/examples/plot_otda_color_images.py +++ b/examples/domain-adaptation/plot_otda_color_images.py @@ -17,8 +17,9 @@ SIAM Journal on Imaging Sciences, 7(3), 1853-1882. # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np -from scipy import ndimage import matplotlib.pylab as pl import ot @@ -45,8 +46,8 @@ def minmax(I): # ------------- # Loading images -I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) X2 = im2mat(I2) diff --git a/examples/plot_otda_d2.py b/examples/domain-adaptation/plot_otda_d2.py index cf22c2f..d8b2a93 100644 --- a/examples/plot_otda_d2.py +++ b/examples/domain-adaptation/plot_otda_d2.py @@ -18,12 +18,14 @@ of what the transport methods are doing. # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import matplotlib.pylab as pl import ot import ot.plot ############################################################################## -# generate data +# Generate data # ------------- n_samples_source = 150 diff --git a/examples/domain-adaptation/plot_otda_jcpot.py b/examples/domain-adaptation/plot_otda_jcpot.py new file mode 100644 index 0000000..c495690 --- /dev/null +++ b/examples/domain-adaptation/plot_otda_jcpot.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +""" +======================== +OT for multi-source target shift +======================== + +This example introduces a target shift problem with two 2D source and 1 target domain. + +""" + +# Authors: Remi Flamary <remi.flamary@unice.fr> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> +# +# License: MIT License + +import pylab as pl +import numpy as np +import ot +from ot.datasets import make_data_classif + +############################################################################## +# Generate data +# ------------- +n = 50 +sigma = 0.3 +np.random.seed(1985) + +p1 = .2 +dec1 = [0, 2] + +p2 = .9 +dec2 = [0, -2] + +pt = .4 +dect = [4, 0] + +xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1) +xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2) +xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect) + +all_Xr = [xs1, xs2] +all_Yr = [ys1, ys2] +# %% + +da = 1.5 + + +def plot_ax(dec, name): + pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) + pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) + pl.text(dec[0] - .5, dec[1] + 2, name) + + +############################################################################## +# Fig 1 : plots source and target samples +# --------------------------------------- + +pl.figure(1) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, + label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, + label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, + label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)) +pl.title('Data') + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Instantiate Sinkhorn transport algorithm and fit them for all source domains +# ---------------------------------------------------------------------------- +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') + + +def print_G(G, xs, ys, xt): + for i in range(G.shape[0]): + for j in range(G.shape[1]): + if G[i, j] > 5e-4: + if ys[i]: + c = 'b' + else: + c = 'r' + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2) + + +############################################################################## +# Fig 2 : plot optimal couplings and transported samples +# ------------------------------------------------------ +pl.figure(2) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt) +print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('Independent OT') + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Instantiate JCPOT adaptation algorithm and fit it +# ---------------------------------------------------------------------------- +otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) +otda.fit(all_Xr, all_Yr, xt) + +ws1 = otda.proportions_.dot(otda.log_['D2'][0]) +ws2 = otda.proportions_.dot(otda.log_['D2'][1]) + +pl.figure(3) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1])) + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Run oracle transport algorithm with known proportions +# ---------------------------------------------------------------------------- +h_res = np.array([1 - pt, pt]) + +ws1 = h_res.dot(otda.log_['D2'][0]) +ws2 = h_res.dot(otda.log_['D2'][1]) + +pl.figure(4) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1])) + +pl.legend() +pl.axis('equal') +pl.axis('off') +pl.show() diff --git a/examples/domain-adaptation/plot_otda_laplacian.py b/examples/domain-adaptation/plot_otda_laplacian.py new file mode 100644 index 0000000..67c8f67 --- /dev/null +++ b/examples/domain-adaptation/plot_otda_laplacian.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +====================================================== +OT with Laplacian regularization for domain adaptation +====================================================== + +This example introduces a domain adaptation in a 2D setting and OTDA +approach with Laplacian regularization. + +""" + +# Authors: Ievgen Redko <ievgen.redko@univ-st-etienne.fr> + +# License: MIT License + +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +n_source_samples = 150 +n_target_samples = 150 + +Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples) +Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples) + + +############################################################################## +# Instantiate the different transport algorithms and fit them +# ----------------------------------------------------------- + +# EMD Transport +ot_emd = ot.da.EMDTransport() +ot_emd.fit(Xs=Xs, Xt=Xt) + +# Sinkhorn Transport +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01) +ot_sinkhorn.fit(Xs=Xs, Xt=Xt) + +# EMD Transport with Laplacian regularization +ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1) +ot_emd_laplace.fit(Xs=Xs, Xt=Xt) + +# transport source samples onto target samples +transp_Xs_emd = ot_emd.transform(Xs=Xs) +transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) +transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs) + +############################################################################## +# Fig 1 : plots source and target samples +# --------------------------------------- + +pl.figure(1, figsize=(10, 5)) +pl.subplot(1, 2, 1) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Source samples') + +pl.subplot(1, 2, 2) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Target samples') +pl.tight_layout() + + +############################################################################## +# Fig 2 : plot optimal couplings and transported samples +# ------------------------------------------------------ + +param_img = {'interpolation': 'nearest'} + +pl.figure(2, figsize=(15, 8)) +pl.subplot(2, 3, 1) +pl.imshow(ot_emd.coupling_, **param_img) +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nEMDTransport') + +pl.figure(2, figsize=(15, 8)) +pl.subplot(2, 3, 2) +pl.imshow(ot_sinkhorn.coupling_, **param_img) +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nSinkhornTransport') + +pl.subplot(2, 3, 3) +pl.imshow(ot_emd_laplace.coupling_, **param_img) +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nEMDLaplaceTransport') + +pl.subplot(2, 3, 4) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.3) +pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.xticks([]) +pl.yticks([]) +pl.title('Transported samples\nEmdTransport') +pl.legend(loc="lower left") + +pl.subplot(2, 3, 5) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.3) +pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.xticks([]) +pl.yticks([]) +pl.title('Transported samples\nSinkhornTransport') + +pl.subplot(2, 3, 6) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.3) +pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.xticks([]) +pl.yticks([]) +pl.title('Transported samples\nEMDLaplaceTransport') +pl.tight_layout() + +pl.show() diff --git a/examples/plot_otda_linear_mapping.py b/examples/domain-adaptation/plot_otda_linear_mapping.py index c65bd4f..dbf16b8 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/domain-adaptation/plot_otda_linear_mapping.py @@ -12,6 +12,8 @@ Linear OT mapping estimation # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import pylab as pl import ot @@ -92,8 +94,8 @@ def minmax(I): # Loading images -I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) diff --git a/examples/plot_otda_mapping.py b/examples/domain-adaptation/plot_otda_mapping.py index 5880adf..d21d3c9 100644 --- a/examples/plot_otda_mapping.py +++ b/examples/domain-adaptation/plot_otda_mapping.py @@ -9,8 +9,8 @@ time both the coupling transport and approximate the transport map with either a linear or a kernelized mapping as introduced in [8]. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, - "Mapping estimation for discrete optimal transport", - Neural Information Processing Systems (NIPS), 2016. +"Mapping estimation for discrete optimal transport", +Neural Information Processing Systems (NIPS), 2016. """ # Authors: Remi Flamary <remi.flamary@unice.fr> @@ -18,6 +18,8 @@ a linear or a kernelized mapping as introduced in [8]. # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_otda_mapping_colors_images.py b/examples/domain-adaptation/plot_otda_mapping_colors_images.py index a20eca8..ee5c8b0 100644 --- a/examples/plot_otda_mapping_colors_images.py +++ b/examples/domain-adaptation/plot_otda_mapping_colors_images.py @@ -8,11 +8,10 @@ OT for domain adaptation with image color adaptation [6] with mapping estimation [8]. [6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized - discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), - 1853-1882. +discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882. + [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for - discrete optimal transport", Neural Information Processing Systems (NIPS), - 2016. +discrete optimal transport", Neural Information Processing Systems (NIPS), 2016. """ @@ -21,8 +20,9 @@ estimation [8]. # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 + import numpy as np -from scipy import ndimage import matplotlib.pylab as pl import ot @@ -48,8 +48,8 @@ def minmax(I): # ------------- # Loading images -I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) diff --git a/examples/plot_otda_semi_supervised.py b/examples/domain-adaptation/plot_otda_semi_supervised.py index 8a67720..478c3b8 100644 --- a/examples/plot_otda_semi_supervised.py +++ b/examples/domain-adaptation/plot_otda_semi_supervised.py @@ -18,6 +18,8 @@ of what the transport methods are doing. # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 + import matplotlib.pylab as pl import ot diff --git a/examples/gromov/README.txt b/examples/gromov/README.txt new file mode 100644 index 0000000..9cc9c64 --- /dev/null +++ b/examples/gromov/README.txt @@ -0,0 +1,4 @@ + + +Gromov and Fused-Gromov-Wasserstein +-----------------------------------
\ No newline at end of file diff --git a/examples/plot_barycenter_fgw.py b/examples/gromov/plot_barycenter_fgw.py index 77b0370..3f81765 100644 --- a/examples/plot_barycenter_fgw.py +++ b/examples/gromov/plot_barycenter_fgw.py @@ -4,14 +4,15 @@ Plot graphs' barycenter using FGW ================================= -This example illustrates the computation barycenter of labeled graphs using FGW +This example illustrates the computation barycenter of labeled graphs using +FGW [18]. Requires networkx >=2 -.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. +[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. """ diff --git a/examples/plot_fgw.py b/examples/gromov/plot_fgw.py index 43efc94..97fe619 100644 --- a/examples/plot_fgw.py +++ b/examples/gromov/plot_fgw.py @@ -4,12 +4,12 @@ Plot Fused-gromov-Wasserstein ============================== -This example illustrates the computation of FGW for 1D measures[18]. +This example illustrates the computation of FGW for 1D measures [18]. -.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain - and Courty Nicolas - "Optimal Transport for structured data with application on graphs" - International Conference on Machine Learning (ICML). 2019. +[18] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain +and Courty Nicolas +"Optimal Transport for structured data with application on graphs" +International Conference on Machine Learning (ICML). 2019. """ @@ -17,6 +17,8 @@ This example illustrates the computation of FGW for 1D measures[18]. # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 + import matplotlib.pyplot as pl import numpy as np import ot @@ -60,14 +62,14 @@ pl.subplot(2, 1, 1) pl.scatter(ys, xs, c=phi, s=70) pl.ylabel('Feature value a', fontsize=20) -pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1) +pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, y=1) pl.xticks(()) pl.yticks(()) pl.subplot(2, 1, 2) pl.scatter(yt, xt, c=phi2, s=70) pl.xlabel('coordinates x/y', fontsize=25) pl.ylabel('Feature value b', fontsize=20) -pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1) +pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, y=1) pl.yticks(()) pl.tight_layout() pl.show() diff --git a/examples/plot_gromov.py b/examples/gromov/plot_gromov.py index deb2f86..deb2f86 100644 --- a/examples/plot_gromov.py +++ b/examples/gromov/plot_gromov.py diff --git a/examples/plot_gromov_barycenter.py b/examples/gromov/plot_gromov_barycenter.py index 58fc51a..f6f031a 100755 --- a/examples/plot_gromov_barycenter.py +++ b/examples/gromov/plot_gromov_barycenter.py @@ -17,7 +17,6 @@ computation in POT. import numpy as np
import scipy as sp
-import scipy.ndimage as spi
import matplotlib.pylab as pl
from sklearn import manifold
from sklearn.decomposition import PCA
@@ -90,10 +89,10 @@ def im2mat(I): return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
-square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
-cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
-triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
-star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+square = pl.imread('../../data/square.png').astype(np.float64)[:, :, 2]
+cross = pl.imread('../../data/cross.png').astype(np.float64)[:, :, 2]
+triangle = pl.imread('../../data/triangle.png').astype(np.float64)[:, :, 2]
+star = pl.imread('../../data/star.png').astype(np.float64)[:, :, 2]
shapes = [square, cross, triangle, star]
diff --git a/examples/others/README.txt b/examples/others/README.txt new file mode 100644 index 0000000..df4c697 --- /dev/null +++ b/examples/others/README.txt @@ -0,0 +1,5 @@ + + + +Other OT problems +-----------------
\ No newline at end of file diff --git a/examples/plot_WDA.py b/examples/others/plot_WDA.py index 93cc237..bdfa57d 100644 --- a/examples/plot_WDA.py +++ b/examples/others/plot_WDA.py @@ -16,6 +16,8 @@ Wasserstein Discriminant Analysis. # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import matplotlib.pylab as pl @@ -31,6 +33,8 @@ from ot.dr import wda, fda n = 1000 # nb samples in source and target datasets nz = 0.2 +np.random.seed(1) + # generate circle dataset t = np.random.rand(n) * 2 * np.pi ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 @@ -86,7 +90,11 @@ reg = 1e0 k = 10 maxiter = 100 -Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) +P0 = np.random.randn(xs.shape[1], p) + +P0 /= np.sqrt(np.sum(P0**2, 0, keepdims=True)) + +Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter, P0=P0) ############################################################################## diff --git a/examples/plot_OT_1D.py b/examples/plot_OT_1D.py index f33e2a4..15ead96 100644 --- a/examples/plot_OT_1D.py +++ b/examples/plot_OT_1D.py @@ -12,6 +12,7 @@ and their visualization. # Author: Remi Flamary <remi.flamary@unice.fr> # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 import numpy as np import matplotlib.pylab as pl diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index b690751..75cd295 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -13,6 +13,8 @@ and their visualization. # # License: MIT License +# sphinx_gallery_thumbnail_number = 6 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_OT_2D_samples.py b/examples/plot_OT_2D_samples.py index 63126ba..1544e82 100644 --- a/examples/plot_OT_2D_samples.py +++ b/examples/plot_OT_2D_samples.py @@ -14,6 +14,8 @@ sum of diracs. The OT matrix is plotted with the samples. # # License: MIT License +# sphinx_gallery_thumbnail_number = 4 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_OT_L1_vs_L2.py b/examples/plot_OT_L1_vs_L2.py index 37b429f..60353ab 100644 --- a/examples/plot_OT_L1_vs_L2.py +++ b/examples/plot_OT_L1_vs_L2.py @@ -16,6 +16,8 @@ https://arxiv.org/pdf/1706.07650.pdf # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_compute_emd.py b/examples/plot_compute_emd.py index 7ed2b01..527a847 100644 --- a/examples/plot_compute_emd.py +++ b/examples/plot_compute_emd.py @@ -4,8 +4,8 @@ Plot multiple EMD ================= -Shows how to compute multiple EMD and Sinkhorn with two differnt -ground metrics and plot their values for diffeent distributions. +Shows how to compute multiple EMD and Sinkhorn with two different +ground metrics and plot their values for different distributions. """ @@ -14,6 +14,8 @@ ground metrics and plot their values for diffeent distributions. # # License: MIT License +# sphinx_gallery_thumbnail_number = 3 + import numpy as np import matplotlib.pylab as pl import ot diff --git a/examples/plot_optim_OTreg.py b/examples/plot_optim_OTreg.py index 2c58def..5eb15bd 100644 --- a/examples/plot_optim_OTreg.py +++ b/examples/plot_optim_OTreg.py @@ -6,7 +6,7 @@ Regularized OT with generic solver Illustrates the use of the generic solver for regularized OT with user-designed regularization term. It uses Conditional gradient as in [6] and -generalized Conditional Gradient as proposed in [5][7]. +generalized Conditional Gradient as proposed in [5,7]. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for @@ -14,8 +14,8 @@ Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). -Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, -7(3), 1853-1882. +Regularized discrete optimal transport. SIAM Journal on Imaging +Sciences, 7(3), 1853-1882. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. @@ -24,6 +24,7 @@ arXiv preprint arXiv:1510.06567. """ +# sphinx_gallery_thumbnail_number = 4 import numpy as np import matplotlib.pylab as pl diff --git a/examples/plot_screenkhorn_1D.py b/examples/plot_screenkhorn_1D.py new file mode 100644 index 0000000..785642a --- /dev/null +++ b/examples/plot_screenkhorn_1D.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +""" +=============================== +1D Screened optimal transport +=============================== + +This example illustrates the computation of Screenkhorn [26]. + +[26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). +Screening Sinkhorn Algorithm for Regularized Optimal Transport, +Advances in Neural Information Processing Systems 33 (NeurIPS). +""" + +# Author: Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot.plot +from ot.datasets import make_1D_gauss as gauss +from ot.bregman import screenkhorn + +############################################################################## +# Generate data +# ------------- + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +# plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + +############################################################################## +# Solve Screenkhorn +# ----------------------- + +# Screenkhorn +lambd = 2e-03 # entropy parameter +ns_budget = 30 # budget number of points to be keeped in the source distribution +nt_budget = 30 # budget number of points to be keeped in the target distribution + +G_screen = screenkhorn(a, b, M, lambd, ns_budget, nt_budget, uniform=False, restricted=True, verbose=True) +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, G_screen, 'OT matrix Screenkhorn') +pl.show() diff --git a/examples/plot_stochastic.py b/examples/plot_stochastic.py index 742f8d9..3a1ef31 100644 --- a/examples/plot_stochastic.py +++ b/examples/plot_stochastic.py @@ -1,10 +1,18 @@ """ -========================== +=================== Stochastic examples -========================== +=================== This example is designed to show how to use the stochatic optimization -algorithms for descrete and semicontinous measures from the POT library. +algorithms for discrete and semi-continuous measures from the POT library. + +[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. +Stochastic Optimization for Large-scale Optimal Transport. +Advances in Neural Information Processing Systems (2016). + +[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & +Blondel, M. Large-scale Optimal Transport and Mapping Estimation. +International Conference on Learning Representation (2018) """ @@ -19,16 +27,14 @@ import ot.plot ############################################################################# -# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM -############################################################################# -############################################################################# -# DISCRETE CASE: +# Compute the Transportation Matrix for the Semi-Dual Problem +# ----------------------------------------------------------- # -# Sample two discrete measures for the discrete case -# --------------------------------------------- +# Discrete case +# ````````````` # -# Define 2 discrete measures a and b, the points where are defined the source -# and the target measures and finally the cost matrix c. +# Sample two discrete measures for the discrete case and compute their cost +# matrix c. n_source = 7 n_target = 4 @@ -44,12 +50,7 @@ Y_target = rng.randn(n_target, 2) M = ot.dist(X_source, Y_target) ############################################################################# -# # Call the "SAG" method to find the transportation matrix in the discrete case -# --------------------------------------------- -# -# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the -# results. method = "SAG" sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, @@ -57,14 +58,12 @@ sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, print(sag_pi) ############################################################################# -# SEMICONTINOUS CASE: +# Semi-Continuous Case +# ```````````````````` # # Sample one general measure a, one discrete measures b for the semicontinous -# case -# --------------------------------------------- -# -# Define one general measure a, one discrete measures b, the points where -# are defined the source and the target measures and finally the cost matrix c. +# case, the points where source and target measures are defined and compute the +# cost matrix. n_source = 7 n_target = 4 @@ -81,13 +80,8 @@ Y_target = rng.randn(n_target, 2) M = ot.dist(X_source, Y_target) ############################################################################# -# # Call the "ASGD" method to find the transportation matrix in the semicontinous -# case -# --------------------------------------------- -# -# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the -# results. +# case. method = "ASGD" asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, @@ -96,23 +90,17 @@ print(log_asgd['alpha'], log_asgd['beta']) print(asgd_pi) ############################################################################# -# # Compare the results with the Sinkhorn algorithm -# --------------------------------------------- -# -# Call the Sinkhorn algorithm from POT sinkhorn_pi = ot.sinkhorn(a, b, M, reg) print(sinkhorn_pi) ############################################################################## -# PLOT TRANSPORTATION MATRIX -############################################################################## - -############################################################################## -# Plot SAG results -# ---------------- +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SAG pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG') @@ -120,8 +108,7 @@ pl.show() ############################################################################## -# Plot ASGD results -# ----------------- +# For ASGD pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD') @@ -129,8 +116,7 @@ pl.show() ############################################################################## -# Plot Sinkhorn results -# --------------------- +# For Sinkhorn pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') @@ -138,17 +124,14 @@ pl.show() ############################################################################# -# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM -############################################################################# -############################################################################# -# SEMICONTINOUS CASE: +# Compute the Transportation Matrix for the Dual Problem +# ------------------------------------------------------ # -# Sample one general measure a, one discrete measures b for the semicontinous -# case -# --------------------------------------------- +# Semi-continuous case +# ```````````````````` # -# Define one general measure a, one discrete measures b, the points where -# are defined the source and the target measures and finally the cost matrix c. +# Sample one general measure a, one discrete measures b for the semi-continuous +# case and compute the cost matrix c. n_source = 7 n_target = 4 @@ -169,10 +152,7 @@ M = ot.dist(X_source, Y_target) ############################################################################# # # Call the "SGD" dual method to find the transportation matrix in the -# semicontinous case -# --------------------------------------------- -# -# Call ot.solve_dual_entropic and plot the results. +# semi-continuous case sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, @@ -183,7 +163,7 @@ print(sgd_dual_pi) ############################################################################# # # Compare the results with the Sinkhorn algorithm -# --------------------------------------------- +# ``````````````````````````````````````````````` # # Call the Sinkhorn algorithm from POT @@ -191,8 +171,10 @@ sinkhorn_pi = ot.sinkhorn(a, b, M, reg) print(sinkhorn_pi) ############################################################################## -# Plot SGD results -# ----------------- +# Plot Transportation Matrices +# ```````````````````````````` +# +# For SGD pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD') @@ -200,8 +182,7 @@ pl.show() ############################################################################## -# Plot Sinkhorn results -# --------------------- +# For Sinkhorn pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn') diff --git a/examples/unbalanced-partial/README.txt b/examples/unbalanced-partial/README.txt new file mode 100644 index 0000000..2f404f0 --- /dev/null +++ b/examples/unbalanced-partial/README.txt @@ -0,0 +1,3 @@ + +Unbalanced and Partial OT +-------------------------
\ No newline at end of file diff --git a/examples/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py index 2ea8b05..2ea8b05 100644 --- a/examples/plot_UOT_1D.py +++ b/examples/unbalanced-partial/plot_UOT_1D.py diff --git a/examples/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py index c8d9d3b..931798b 100644 --- a/examples/plot_UOT_barycenter_1D.py +++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py @@ -16,6 +16,8 @@ as proposed in [10] for Unbalanced inputs. # # License: MIT License +# sphinx_gallery_thumbnail_number = 2 + import numpy as np import matplotlib.pylab as pl import ot @@ -77,7 +79,7 @@ bary_l2 = A.dot(weights) reg = 1e-3 alpha = 1. -bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) pl.figure(2) pl.clf() @@ -111,7 +113,7 @@ for i in range(0, n_weight): weight = weight_list[i] weights = np.array([1 - weight, weight]) B_l2[:, i] = A.dot(weights) - B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights=weights) # plot interpolation diff --git a/examples/unbalanced-partial/plot_partial_wass_and_gromov.py b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py new file mode 100755 index 0000000..0c5cbf9 --- /dev/null +++ b/examples/unbalanced-partial/plot_partial_wass_and_gromov.py @@ -0,0 +1,165 @@ +# -*- coding: utf-8 -*-
+"""
+==================================================
+Partial Wasserstein and Gromov-Wasserstein example
+==================================================
+
+This example is designed to show how to use the Partial (Gromov-)Wassertsein
+distance computation in POT.
+"""
+
+# Author: Laetitia Chapel <laetitia.chapel@irisa.fr>
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 2
+
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+#############################################################################
+#
+# Sample two 2D Gaussian distributions and plot them
+# --------------------------------------------------
+#
+# For demonstration purpose, we sample two Gaussian distributions in 2-d
+# spaces and add some random noise.
+
+
+n_samples = 20 # nb samples (gaussian)
+n_noise = 20 # nb of samples (noise)
+
+mu = np.array([0, 0])
+cov = np.array([[1, 0], [0, 2]])
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xs = np.append(xs, (np.random.rand(n_noise, 2) + 1) * 4).reshape((-1, 2))
+xt = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov)
+xt = np.append(xt, (np.random.rand(n_noise, 2) + 1) * -3).reshape((-1, 2))
+
+M = sp.spatial.distance.cdist(xs, xt)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(131)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(132)
+ax2.scatter(xt[:, 0], xt[:, 1], color='r')
+ax3 = fig.add_subplot(133)
+ax3.imshow(M)
+pl.show()
+
+#############################################################################
+#
+# Compute partial Wasserstein plans and distance
+# ----------------------------------------------
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+w0, log0 = ot.partial.partial_wasserstein(p, q, M, m=0.5, log=True)
+w, log = ot.partial.entropic_partial_wasserstein(p, q, M, reg=0.1, m=0.5,
+ log=True)
+
+print('Partial Wasserstein distance (m = 0.5): ' + str(log0['partial_w_dist']))
+print('Entropic partial Wasserstein distance (m = 0.5): ' +
+ str(log['partial_w_dist']))
+
+pl.figure(1, (10, 5))
+pl.subplot(1, 2, 1)
+pl.imshow(w0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(w, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()
+
+
+#############################################################################
+#
+# Sample one 2D and 3D Gaussian distributions and plot them
+# ---------------------------------------------------------
+#
+# The Gromov-Wasserstein distance allows to compute distances with samples that
+# do not belong to the same metric space. For demonstration purpose, we sample
+# two Gaussian distributions in 2- and 3-dimensional spaces.
+
+n_samples = 20 # nb samples
+n_noise = 10 # nb of samples (noise)
+
+p = ot.unif(n_samples + n_noise)
+q = ot.unif(n_samples + n_noise)
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([0, 0, 0])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+#############################################################################
+#
+# Compute partial Gromov-Wasserstein plans and distance
+# -----------------------------------------------------
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+# transport 100% of the mass
+print('-----m = 1')
+m = 1
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Wasserstein distance (m = 1): ' + str(log0['partial_gw_dist']))
+print('Entropic Wasserstein distance (m = 1): ' + str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 1")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic Wasserstein')
+pl.show()
+
+# transport 2/3 of the mass
+print('-----m = 2/3')
+m = 2 / 3
+res0, log0 = ot.partial.partial_gromov_wasserstein(C1, C2, p, q, m=m, log=True)
+res, log = ot.partial.entropic_partial_gromov_wasserstein(C1, C2, p, q, 10,
+ m=m, log=True)
+
+print('Partial Wasserstein distance (m = 2/3): ' +
+ str(log0['partial_gw_dist']))
+print('Entropic partial Wasserstein distance (m = 2/3): ' +
+ str(log['partial_gw_dist']))
+
+pl.figure(1, (10, 5))
+pl.title("mass to be transported m = 2/3")
+pl.subplot(1, 2, 1)
+pl.imshow(res0, cmap='jet')
+pl.title('Partial Wasserstein')
+pl.subplot(1, 2, 2)
+pl.imshow(res, cmap='jet')
+pl.title('Entropic partial Wasserstein')
+pl.show()
|