summaryrefslogtreecommitdiff
path: root/docs
diff options
context:
space:
mode:
Diffstat (limited to 'docs')
-rw-r--r--docs/Makefile216
-rw-r--r--docs/cache_nbrun1
-rw-r--r--docs/make.bat263
-rwxr-xr-xdocs/nb_build15
-rwxr-xr-xdocs/nb_run_conv82
-rw-r--r--docs/source/all.rst88
-rw-r--r--docs/source/auto_examples/auto_examples_jupyter.zipbin0 -> 148147 bytes
-rw-r--r--docs/source/auto_examples/auto_examples_python.zipbin0 -> 99229 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.pngbin0 -> 21372 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.pngbin0 -> 22051 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.pngbin0 -> 17080 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.pngbin0 -> 19019 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_001.pngbin0 -> 21372 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_002.pngbin0 -> 22051 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_005.pngbin0 -> 17080 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_007.pngbin0 -> 19405 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_009.pngbin0 -> 20630 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_010.pngbin0 -> 19232 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.pngbin0 -> 20785 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.pngbin0 -> 21134 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.pngbin0 -> 9704 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.pngbin0 -> 79153 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.pngbin0 -> 14611 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.pngbin0 -> 97487 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.pngbin0 -> 10846 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.pngbin0 -> 20361 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.pngbin0 -> 11773 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.pngbin0 -> 17253 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.pngbin0 -> 38780 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.pngbin0 -> 11710 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.pngbin0 -> 38849 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.pngbin0 -> 38780 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.pngbin0 -> 14186 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.pngbin0 -> 18765 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.pngbin0 -> 21300 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.pngbin0 -> 21369 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.pngbin0 -> 21239 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.pngbin0 -> 22051 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.pngbin0 -> 21288 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.pngbin0 -> 22177 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.pngbin0 -> 42539 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.pngbin0 -> 105997 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.pngbin0 -> 103234 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_WDA_001.pngbin0 -> 56604 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_WDA_003.pngbin0 -> 87031 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.pngbin0 -> 20581 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.pngbin0 -> 41624 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.pngbin0 -> 41624 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.pngbin0 -> 105765 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.pngbin0 -> 108756 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.pngbin0 -> 105765 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.pngbin0 -> 131827 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.pngbin0 -> 29423 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.pngbin0 -> 20581 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.pngbin0 -> 46114 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.pngbin0 -> 14405 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.pngbin0 -> 33271 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.pngbin0 -> 70940 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.pngbin0 -> 162681 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.pngbin0 -> 29345 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.pngbin0 -> 38817 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_convolutional_barycenter_001.pngbin0 -> 319138 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_fgw_004.pngbin0 -> 19490 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_fgw_010.pngbin0 -> 44747 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_fgw_011.pngbin0 -> 21337 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_free_support_barycenter_001.pngbin0 -> 31553 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_gromov_001.pngbin0 -> 44988 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_gromov_002.pngbin0 -> 17066 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_gromov_003.pngbin0 -> 18663 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.pngbin0 -> 48271 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.pngbin0 -> 17080 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.pngbin0 -> 19084 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.pngbin0 -> 19317 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.pngbin0 -> 20484 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.pngbin0 -> 50516 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.pngbin0 -> 207861 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.pngbin0 -> 145014 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.pngbin0 -> 50472 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.pngbin0 -> 326766 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.pngbin0 -> 134104 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.pngbin0 -> 231768 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.pngbin0 -> 107918 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_001.pngbin0 -> 29432 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_002.pngbin0 -> 53979 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_004.pngbin0 -> 591554 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.pngbin0 -> 38663 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.pngbin0 -> 76079 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.pngbin0 -> 165658 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.pngbin0 -> 80796 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.pngbin0 -> 512309 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.pngbin0 -> 158896 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.pngbin0 -> 36909 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.pngbin0 -> 80769 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_stochastic_004.pngbin0 -> 10450 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_stochastic_005.pngbin0 -> 10677 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_stochastic_006.pngbin0 -> 9131 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_stochastic_007.pngbin0 -> 9483 bytes
-rw-r--r--docs/source/auto_examples/images/sphx_glr_plot_stochastic_008.pngbin0 -> 9131 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.pngbin0 -> 14983 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.pngbin0 -> 14983 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.pngbin0 -> 17987 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.pngbin0 -> 9377 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.pngbin0 -> 14761 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.pngbin0 -> 15099 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.pngbin0 -> 86417 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.pngbin0 -> 13542 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.pngbin0 -> 28694 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_lp_vs_entropic_thumb.pngbin0 -> 13542 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.pngbin0 -> 76133 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_convolutional_barycenter_thumb.pngbin0 -> 54369 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.pngbin0 -> 17541 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.pngbin0 -> 19601 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.pngbin0 -> 28787 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.pngbin0 -> 25604 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.pngbin0 -> 3101 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.pngbin0 -> 23180 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.pngbin0 -> 49131 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.pngbin0 -> 48206 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_linear_mapping_thumb.pngbin0 -> 21399 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.pngbin0 -> 56216 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.pngbin0 -> 15931 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.pngbin0 -> 60596 bytes
-rw-r--r--docs/source/auto_examples/images/thumb/sphx_glr_plot_stochastic_thumb.pngbin0 -> 17541 bytes
-rw-r--r--docs/source/auto_examples/index.rst535
-rw-r--r--docs/source/auto_examples/plot_OT_1D.ipynb126
-rw-r--r--docs/source/auto_examples/plot_OT_1D.py84
-rw-r--r--docs/source/auto_examples/plot_OT_1D.rst199
-rw-r--r--docs/source/auto_examples/plot_OT_1D_smooth.ipynb144
-rw-r--r--docs/source/auto_examples/plot_OT_1D_smooth.py110
-rw-r--r--docs/source/auto_examples/plot_OT_1D_smooth.rst242
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.ipynb144
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.py128
-rw-r--r--docs/source/auto_examples/plot_OT_2D_samples.rst273
-rw-r--r--docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb126
-rw-r--r--docs/source/auto_examples/plot_OT_L1_vs_L2.py208
-rw-r--r--docs/source/auto_examples/plot_OT_L1_vs_L2.rst318
-rw-r--r--docs/source/auto_examples/plot_UOT_1D.ipynb108
-rw-r--r--docs/source/auto_examples/plot_UOT_1D.py76
-rw-r--r--docs/source/auto_examples/plot_UOT_1D.rst173
-rw-r--r--docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb126
-rw-r--r--docs/source/auto_examples/plot_UOT_barycenter_1D.py164
-rw-r--r--docs/source/auto_examples/plot_UOT_barycenter_1D.rst261
-rw-r--r--docs/source/auto_examples/plot_WDA.ipynb144
-rw-r--r--docs/source/auto_examples/plot_WDA.py127
-rw-r--r--docs/source/auto_examples/plot_WDA.rst244
-rw-r--r--docs/source/auto_examples/plot_barycenter_1D.ipynb126
-rw-r--r--docs/source/auto_examples/plot_barycenter_1D.py160
-rw-r--r--docs/source/auto_examples/plot_barycenter_1D.rst257
-rw-r--r--docs/source/auto_examples/plot_barycenter_fgw.ipynb126
-rw-r--r--docs/source/auto_examples/plot_barycenter_fgw.py184
-rw-r--r--docs/source/auto_examples/plot_barycenter_fgw.rst268
-rw-r--r--docs/source/auto_examples/plot_barycenter_lp_vs_entropic.ipynb108
-rw-r--r--docs/source/auto_examples/plot_barycenter_lp_vs_entropic.py281
-rw-r--r--docs/source/auto_examples/plot_barycenter_lp_vs_entropic.rst447
-rw-r--r--docs/source/auto_examples/plot_compute_emd.ipynb126
-rw-r--r--docs/source/auto_examples/plot_compute_emd.py102
-rw-r--r--docs/source/auto_examples/plot_compute_emd.rst189
-rw-r--r--docs/source/auto_examples/plot_convolutional_barycenter.ipynb90
-rw-r--r--docs/source/auto_examples/plot_convolutional_barycenter.py92
-rw-r--r--docs/source/auto_examples/plot_convolutional_barycenter.rst151
-rw-r--r--docs/source/auto_examples/plot_fgw.ipynb162
-rw-r--r--docs/source/auto_examples/plot_fgw.py173
-rw-r--r--docs/source/auto_examples/plot_fgw.rst297
-rw-r--r--docs/source/auto_examples/plot_free_support_barycenter.ipynb108
-rw-r--r--docs/source/auto_examples/plot_free_support_barycenter.py69
-rw-r--r--docs/source/auto_examples/plot_free_support_barycenter.rst140
-rw-r--r--docs/source/auto_examples/plot_gromov.ipynb126
-rw-r--r--docs/source/auto_examples/plot_gromov.py106
-rw-r--r--docs/source/auto_examples/plot_gromov.rst214
-rw-r--r--docs/source/auto_examples/plot_gromov_barycenter.ipynb126
-rw-r--r--docs/source/auto_examples/plot_gromov_barycenter.py248
-rw-r--r--docs/source/auto_examples/plot_gromov_barycenter.rst329
-rw-r--r--docs/source/auto_examples/plot_optim_OTreg.ipynb144
-rw-r--r--docs/source/auto_examples/plot_optim_OTreg.py129
-rw-r--r--docs/source/auto_examples/plot_optim_OTreg.rst663
-rw-r--r--docs/source/auto_examples/plot_otda_classes.ipynb126
-rw-r--r--docs/source/auto_examples/plot_otda_classes.py150
-rw-r--r--docs/source/auto_examples/plot_otda_classes.rst263
-rw-r--r--docs/source/auto_examples/plot_otda_color_images.ipynb144
-rw-r--r--docs/source/auto_examples/plot_otda_color_images.py165
-rw-r--r--docs/source/auto_examples/plot_otda_color_images.rst262
-rw-r--r--docs/source/auto_examples/plot_otda_d2.ipynb144
-rw-r--r--docs/source/auto_examples/plot_otda_d2.py172
-rw-r--r--docs/source/auto_examples/plot_otda_d2.rst269
-rw-r--r--docs/source/auto_examples/plot_otda_linear_mapping.ipynb180
-rw-r--r--docs/source/auto_examples/plot_otda_linear_mapping.py144
-rw-r--r--docs/source/auto_examples/plot_otda_linear_mapping.rst260
-rw-r--r--docs/source/auto_examples/plot_otda_mapping.ipynb126
-rw-r--r--docs/source/auto_examples/plot_otda_mapping.py125
-rw-r--r--docs/source/auto_examples/plot_otda_mapping.rst235
-rw-r--r--docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb144
-rw-r--r--docs/source/auto_examples/plot_otda_mapping_colors_images.py174
-rw-r--r--docs/source/auto_examples/plot_otda_mapping_colors_images.rst310
-rw-r--r--docs/source/auto_examples/plot_otda_semi_supervised.ipynb144
-rw-r--r--docs/source/auto_examples/plot_otda_semi_supervised.py148
-rw-r--r--docs/source/auto_examples/plot_otda_semi_supervised.rst245
-rw-r--r--docs/source/auto_examples/plot_stochastic.ipynb295
-rw-r--r--docs/source/auto_examples/plot_stochastic.py208
-rw-r--r--docs/source/auto_examples/plot_stochastic.rst446
-rw-r--r--docs/source/auto_examples/searchindexbin0 -> 1892352 bytes
-rw-r--r--docs/source/conf.py343
-rw-r--r--docs/source/index.rst29
-rw-r--r--docs/source/quickstart.rst923
-rw-r--r--docs/source/readme.rst407
204 files changed, 17047 insertions, 0 deletions
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..3511a59
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,216 @@
+# Makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line.
+SPHINXOPTS =
+SPHINXBUILD = sphinx-build
+PAPER =
+BUILDDIR = build
+
+# User-friendly check for sphinx-build
+ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
+$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
+endif
+
+# Internal variables.
+PAPEROPT_a4 = -D latex_paper_size=a4
+PAPEROPT_letter = -D latex_paper_size=letter
+ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source
+# the i18n builder cannot share the environment and doctrees with the others
+I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source
+
+.PHONY: help
+help:
+ @echo "Please use \`make <target>' where <target> is one of"
+ @echo " html to make standalone HTML files"
+ @echo " dirhtml to make HTML files named index.html in directories"
+ @echo " singlehtml to make a single large HTML file"
+ @echo " pickle to make pickle files"
+ @echo " json to make JSON files"
+ @echo " htmlhelp to make HTML files and a HTML help project"
+ @echo " qthelp to make HTML files and a qthelp project"
+ @echo " applehelp to make an Apple Help Book"
+ @echo " devhelp to make HTML files and a Devhelp project"
+ @echo " epub to make an epub"
+ @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
+ @echo " latexpdf to make LaTeX files and run them through pdflatex"
+ @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
+ @echo " text to make text files"
+ @echo " man to make manual pages"
+ @echo " texinfo to make Texinfo files"
+ @echo " info to make Texinfo files and run them through makeinfo"
+ @echo " gettext to make PO message catalogs"
+ @echo " changes to make an overview of all changed/added/deprecated items"
+ @echo " xml to make Docutils-native XML files"
+ @echo " pseudoxml to make pseudoxml-XML files for display purposes"
+ @echo " linkcheck to check all external links for integrity"
+ @echo " doctest to run all doctests embedded in the documentation (if enabled)"
+ @echo " coverage to run coverage check of the documentation (if enabled)"
+
+.PHONY: clean
+clean:
+ rm -rf $(BUILDDIR)/*
+
+.PHONY: html
+html:
+ $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
+ @echo
+ @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+.PHONY: dirhtml
+dirhtml:
+ $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
+ @echo
+ @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
+
+.PHONY: singlehtml
+singlehtml:
+ $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
+ @echo
+ @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
+
+.PHONY: pickle
+pickle:
+ $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
+ @echo
+ @echo "Build finished; now you can process the pickle files."
+
+.PHONY: json
+json:
+ $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
+ @echo
+ @echo "Build finished; now you can process the JSON files."
+
+.PHONY: htmlhelp
+htmlhelp:
+ $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
+ @echo
+ @echo "Build finished; now you can run HTML Help Workshop with the" \
+ ".hhp project file in $(BUILDDIR)/htmlhelp."
+
+.PHONY: qthelp
+qthelp:
+ $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
+ @echo
+ @echo "Build finished; now you can run "qcollectiongenerator" with the" \
+ ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
+ @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/POT.qhcp"
+ @echo "To view the help file:"
+ @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/POT.qhc"
+
+.PHONY: applehelp
+applehelp:
+ $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
+ @echo
+ @echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
+ @echo "N.B. You won't be able to view it unless you put it in" \
+ "~/Library/Documentation/Help or install it in your application" \
+ "bundle."
+
+.PHONY: devhelp
+devhelp:
+ $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
+ @echo
+ @echo "Build finished."
+ @echo "To view the help file:"
+ @echo "# mkdir -p $$HOME/.local/share/devhelp/POT"
+ @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/POT"
+ @echo "# devhelp"
+
+.PHONY: epub
+epub:
+ $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
+ @echo
+ @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
+
+.PHONY: latex
+latex:
+ $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
+ @echo
+ @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
+ @echo "Run \`make' in that directory to run these through (pdf)latex" \
+ "(use \`make latexpdf' here to do that automatically)."
+
+.PHONY: latexpdf
+latexpdf:
+ $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
+ @echo "Running LaTeX files through pdflatex..."
+ $(MAKE) -C $(BUILDDIR)/latex all-pdf
+ @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
+
+.PHONY: latexpdfja
+latexpdfja:
+ $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
+ @echo "Running LaTeX files through platex and dvipdfmx..."
+ $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
+ @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
+
+.PHONY: text
+text:
+ $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
+ @echo
+ @echo "Build finished. The text files are in $(BUILDDIR)/text."
+
+.PHONY: man
+man:
+ $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
+ @echo
+ @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
+
+.PHONY: texinfo
+texinfo:
+ $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
+ @echo
+ @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
+ @echo "Run \`make' in that directory to run these through makeinfo" \
+ "(use \`make info' here to do that automatically)."
+
+.PHONY: info
+info:
+ $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
+ @echo "Running Texinfo files through makeinfo..."
+ make -C $(BUILDDIR)/texinfo info
+ @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
+
+.PHONY: gettext
+gettext:
+ $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
+ @echo
+ @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
+
+.PHONY: changes
+changes:
+ $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
+ @echo
+ @echo "The overview file is in $(BUILDDIR)/changes."
+
+.PHONY: linkcheck
+linkcheck:
+ $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
+ @echo
+ @echo "Link check complete; look for any errors in the above output " \
+ "or in $(BUILDDIR)/linkcheck/output.txt."
+
+.PHONY: doctest
+doctest:
+ $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
+ @echo "Testing of doctests in the sources finished, look at the " \
+ "results in $(BUILDDIR)/doctest/output.txt."
+
+.PHONY: coverage
+coverage:
+ $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
+ @echo "Testing of coverage in the sources finished, look at the " \
+ "results in $(BUILDDIR)/coverage/python.txt."
+
+.PHONY: xml
+xml:
+ $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
+ @echo
+ @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
+
+.PHONY: pseudoxml
+pseudoxml:
+ $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
+ @echo
+ @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
diff --git a/docs/cache_nbrun b/docs/cache_nbrun
new file mode 100644
index 0000000..8a95023
--- /dev/null
+++ b/docs/cache_nbrun
@@ -0,0 +1 @@
+{"plot_otda_semi_supervised.ipynb": "f6dfb02ba2bbd939408ffcd22a3b007c", "plot_WDA.ipynb": "27f8de4c6d7db46497076523673eedfb", "plot_UOT_1D.ipynb": "fc7dd383e625597bd59fff03a8430c91", "plot_OT_L1_vs_L2.ipynb": "5d565b8aaf03be4309eba731127851dc", "plot_otda_color_images.ipynb": "f804d5806c7ac1a0901e4542b1eaa77b", "plot_fgw.ipynb": "2ba3e100e92ecf4dfbeb605de20b40ab", "plot_otda_d2.ipynb": "e6feae588103f2a8fab942e5f4eff483", "plot_compute_emd.ipynb": "f5cd71cad882ec157dc8222721e9820c", "plot_barycenter_fgw.ipynb": "e14100dd276bff3ffdfdf176f1b6b070", "plot_convolutional_barycenter.ipynb": "a72bb3716a1baaffd81ae267a673f9b6", "plot_optim_OTreg.ipynb": "481801bb0d133ef350a65179cf8f739a", "plot_barycenter_lp_vs_entropic.ipynb": "51833e8c76aaedeba9599ac7a30eb357", "plot_OT_1D_smooth.ipynb": "3a059103652225a0c78ea53895cf79e5", "plot_barycenter_1D.ipynb": "5f6fb8aebd8e2e91ebc77c923cb112b3", "plot_otda_mapping.ipynb": "2f1ebbdc0f855d9e2b7adf9edec24d25", "plot_OT_1D.ipynb": "b5348bdc561c07ec168a1622e5af4b93", "plot_gromov_barycenter.ipynb": "953e5047b886ec69ec621ec52f5e21d1", "plot_UOT_barycenter_1D.ipynb": "c72f0bfb6e1a79710dad3fef9f5c557c", "plot_otda_mapping_colors_images.ipynb": "cc8bf9a857f52e4a159fe71dfda19018", "plot_stochastic.ipynb": "e18253354c8c1d72567a4259eb1094f7", "plot_otda_linear_mapping.ipynb": "a472c767abe82020e0a58125a528785c", "plot_otda_classes.ipynb": "39087b6e98217851575f2271c22853a4", "plot_free_support_barycenter.ipynb": "246dd2feff4b233a4f1a553c5a202fdc", "plot_gromov.ipynb": "24f2aea489714d34779521f46d5e2c47", "plot_OT_2D_samples.ipynb": "912a77c5dd0fc0fafa03fac3d86f1502"} \ No newline at end of file
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..f2313df
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,263 @@
+@ECHO OFF
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set BUILDDIR=build
+set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% source
+set I18NSPHINXOPTS=%SPHINXOPTS% source
+if NOT "%PAPER%" == "" (
+ set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
+ set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
+)
+
+if "%1" == "" goto help
+
+if "%1" == "help" (
+ :help
+ echo.Please use `make ^<target^>` where ^<target^> is one of
+ echo. html to make standalone HTML files
+ echo. dirhtml to make HTML files named index.html in directories
+ echo. singlehtml to make a single large HTML file
+ echo. pickle to make pickle files
+ echo. json to make JSON files
+ echo. htmlhelp to make HTML files and a HTML help project
+ echo. qthelp to make HTML files and a qthelp project
+ echo. devhelp to make HTML files and a Devhelp project
+ echo. epub to make an epub
+ echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
+ echo. text to make text files
+ echo. man to make manual pages
+ echo. texinfo to make Texinfo files
+ echo. gettext to make PO message catalogs
+ echo. changes to make an overview over all changed/added/deprecated items
+ echo. xml to make Docutils-native XML files
+ echo. pseudoxml to make pseudoxml-XML files for display purposes
+ echo. linkcheck to check all external links for integrity
+ echo. doctest to run all doctests embedded in the documentation if enabled
+ echo. coverage to run coverage check of the documentation if enabled
+ goto end
+)
+
+if "%1" == "clean" (
+ for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
+ del /q /s %BUILDDIR%\*
+ goto end
+)
+
+
+REM Check if sphinx-build is available and fallback to Python version if any
+%SPHINXBUILD% 1>NUL 2>NUL
+if errorlevel 9009 goto sphinx_python
+goto sphinx_ok
+
+:sphinx_python
+
+set SPHINXBUILD=python -m sphinx.__init__
+%SPHINXBUILD% 2> nul
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+:sphinx_ok
+
+
+if "%1" == "html" (
+ %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/html.
+ goto end
+)
+
+if "%1" == "dirhtml" (
+ %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
+ goto end
+)
+
+if "%1" == "singlehtml" (
+ %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
+ goto end
+)
+
+if "%1" == "pickle" (
+ %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can process the pickle files.
+ goto end
+)
+
+if "%1" == "json" (
+ %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can process the JSON files.
+ goto end
+)
+
+if "%1" == "htmlhelp" (
+ %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can run HTML Help Workshop with the ^
+.hhp project file in %BUILDDIR%/htmlhelp.
+ goto end
+)
+
+if "%1" == "qthelp" (
+ %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can run "qcollectiongenerator" with the ^
+.qhcp project file in %BUILDDIR%/qthelp, like this:
+ echo.^> qcollectiongenerator %BUILDDIR%\qthelp\POT.qhcp
+ echo.To view the help file:
+ echo.^> assistant -collectionFile %BUILDDIR%\qthelp\POT.ghc
+ goto end
+)
+
+if "%1" == "devhelp" (
+ %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished.
+ goto end
+)
+
+if "%1" == "epub" (
+ %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The epub file is in %BUILDDIR%/epub.
+ goto end
+)
+
+if "%1" == "latex" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "latexpdf" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ cd %BUILDDIR%/latex
+ make all-pdf
+ cd %~dp0
+ echo.
+ echo.Build finished; the PDF files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "latexpdfja" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ cd %BUILDDIR%/latex
+ make all-pdf-ja
+ cd %~dp0
+ echo.
+ echo.Build finished; the PDF files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "text" (
+ %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The text files are in %BUILDDIR%/text.
+ goto end
+)
+
+if "%1" == "man" (
+ %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The manual pages are in %BUILDDIR%/man.
+ goto end
+)
+
+if "%1" == "texinfo" (
+ %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
+ goto end
+)
+
+if "%1" == "gettext" (
+ %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
+ goto end
+)
+
+if "%1" == "changes" (
+ %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.The overview file is in %BUILDDIR%/changes.
+ goto end
+)
+
+if "%1" == "linkcheck" (
+ %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Link check complete; look for any errors in the above output ^
+or in %BUILDDIR%/linkcheck/output.txt.
+ goto end
+)
+
+if "%1" == "doctest" (
+ %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Testing of doctests in the sources finished, look at the ^
+results in %BUILDDIR%/doctest/output.txt.
+ goto end
+)
+
+if "%1" == "coverage" (
+ %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Testing of coverage in the sources finished, look at the ^
+results in %BUILDDIR%/coverage/python.txt.
+ goto end
+)
+
+if "%1" == "xml" (
+ %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The XML files are in %BUILDDIR%/xml.
+ goto end
+)
+
+if "%1" == "pseudoxml" (
+ %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
+ goto end
+)
+
+:end
diff --git a/docs/nb_build b/docs/nb_build
new file mode 100755
index 0000000..6abc6cf
--- /dev/null
+++ b/docs/nb_build
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+
+# remove comment
+sed -i "s/#'sphinx\_gallery/'sphinx\_gallery/" source/conf.py
+sed -i "s/sys.modules.update/#sys.modules.update/" source/conf.py
+
+make html
+
+# put comment again
+sed -i "s/'sphinx\_gallery/#'sphinx\_gallery/" source/conf.py
+sed -i "s/#sys.modules.update/sys.modules.update/" source/conf.py
+
+#rsync --out-format="%n" --update source/auto_examples/*.ipynb ../notebooks2
+./nb_run_conv
diff --git a/docs/nb_run_conv b/docs/nb_run_conv
new file mode 100755
index 0000000..ad5e432
--- /dev/null
+++ b/docs/nb_run_conv
@@ -0,0 +1,82 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+
+Convert sphinx gallery notebook from empty to image filled
+
+Created on Fri Sep 1 16:43:45 2017
+
+@author: rflamary
+"""
+
+import sys
+import json
+import glob
+import hashlib
+import subprocess
+
+import os
+
+cache_file='cache_nbrun'
+
+path_doc='source/auto_examples/'
+path_nb='../notebooks/'
+
+def load_json(fname):
+ try:
+ f=open(fname)
+ nb=json.load(f)
+ f.close()
+ except (OSError, IOError) :
+ nb={}
+ return nb
+
+def save_json(fname,nb):
+ f=open(fname,'w')
+ f.write(json.dumps(nb))
+ f.close()
+
+
+def md5(fname):
+ hash_md5 = hashlib.md5()
+ with open(fname, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_md5.update(chunk)
+ return hash_md5.hexdigest()
+
+def to_update(fname,cache):
+ if fname in cache:
+ if md5(path_doc+fname)==cache[fname]:
+ res=False
+ else:
+ res=True
+ else:
+ res=True
+
+ return res
+
+def update(fname,cache):
+
+ # jupyter nbconvert --to notebook --execute mynotebook.ipynb --output targte
+ subprocess.check_call(['cp',path_doc+fname,path_nb])
+ print(' '.join(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace']))
+ subprocess.check_call(['jupyter','nbconvert','--to','notebook','--ExecutePreprocessor.timeout=600','--execute',path_nb+fname,'--inplace'])
+ cache[fname]=md5(path_doc+fname)
+
+
+
+cache=load_json(cache_file)
+
+lst_file=glob.glob(path_doc+'*.ipynb')
+
+lst_file=[os.path.basename(name) for name in lst_file]
+
+for fname in lst_file:
+ if to_update(fname,cache):
+ print('Updating file: {}'.format(fname))
+ update(fname,cache)
+ save_json(cache_file,cache)
+
+
+
+
diff --git a/docs/source/all.rst b/docs/source/all.rst
new file mode 100644
index 0000000..c968aa1
--- /dev/null
+++ b/docs/source/all.rst
@@ -0,0 +1,88 @@
+
+
+Python modules
+==============
+
+ot
+--
+
+.. automodule:: ot
+ :members:
+
+ot.lp
+-----
+.. automodule:: ot.lp
+ :members:
+
+ot.bregman
+----------
+
+.. automodule:: ot.bregman
+ :members:
+
+ot.smooth
+-----
+.. automodule:: ot.smooth
+ :members:
+
+ot.gromov
+----------
+
+.. automodule:: ot.gromov
+ :members:
+
+
+ot.optim
+--------
+
+.. automodule:: ot.optim
+ :members:
+
+ot.da
+--------
+
+.. automodule:: ot.da
+ :members:
+
+ot.gpu
+--------
+
+.. automodule:: ot.gpu
+ :members:
+
+ot.dr
+--------
+
+.. automodule:: ot.dr
+ :members:
+
+
+ot.utils
+--------
+
+.. automodule:: ot.utils
+ :members:
+
+ot.datasets
+-----------
+
+.. automodule:: ot.datasets
+ :members:
+
+ot.plot
+-------
+
+.. automodule:: ot.plot
+ :members:
+
+ot.stochastic
+-------------
+
+.. automodule:: ot.stochastic
+ :members:
+
+ot.unbalanced
+-------------
+
+.. automodule:: ot.unbalanced
+ :members:
diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip
new file mode 100644
index 0000000..901195a
--- /dev/null
+++ b/docs/source/auto_examples/auto_examples_jupyter.zip
Binary files differ
diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip
new file mode 100644
index 0000000..ded2613
--- /dev/null
+++ b/docs/source/auto_examples/auto_examples_python.zip
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
new file mode 100644
index 0000000..6e74d89
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
new file mode 100644
index 0000000..0407e44
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png
new file mode 100644
index 0000000..4421bc7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png
new file mode 100644
index 0000000..2dbe49b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_001.png
new file mode 100644
index 0000000..6e74d89
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_002.png
new file mode 100644
index 0000000..0407e44
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_005.png
new file mode 100644
index 0000000..4421bc7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_007.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_007.png
new file mode 100644
index 0000000..52638e3
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_007.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_009.png
new file mode 100644
index 0000000..c5078cf
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_009.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_010.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_010.png
new file mode 100644
index 0000000..58e87b6
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_smooth_010.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
new file mode 100644
index 0000000..a5bded7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
new file mode 100644
index 0000000..1d90c2d
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
new file mode 100644
index 0000000..ea6a405
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
new file mode 100644
index 0000000..8bc46dc
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
new file mode 100644
index 0000000..56d18ef
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
new file mode 100644
index 0000000..5aef7d2
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png
new file mode 100644
index 0000000..bb8bd7c
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png
new file mode 100644
index 0000000..30cec7b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
new file mode 100644
index 0000000..3b1a29e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
new file mode 100644
index 0000000..5a33824
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
new file mode 100644
index 0000000..4860d96
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
new file mode 100644
index 0000000..6a21f35
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
new file mode 100644
index 0000000..1108375
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
new file mode 100644
index 0000000..4860d96
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png
new file mode 100644
index 0000000..6e6f3b9
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png
new file mode 100644
index 0000000..007d246
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png
new file mode 100644
index 0000000..e1e9ba8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png
new file mode 100644
index 0000000..75ef929
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png
new file mode 100644
index 0000000..69ef5b7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png
new file mode 100644
index 0000000..0407e44
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png
new file mode 100644
index 0000000..f58d383
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_1D_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png
new file mode 100644
index 0000000..ec8c51e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png
new file mode 100644
index 0000000..89ab265
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png
new file mode 100644
index 0000000..c6c49cb
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png
new file mode 100644
index 0000000..8870b10
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png b/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
new file mode 100644
index 0000000..3524e19
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png b/docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png
new file mode 100644
index 0000000..819b974
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
new file mode 100644
index 0000000..3500812
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
new file mode 100644
index 0000000..d8db85e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
new file mode 100644
index 0000000..d8db85e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
new file mode 100644
index 0000000..bfa0873
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
new file mode 100644
index 0000000..81cee52
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
new file mode 100644
index 0000000..bfa0873
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png
new file mode 100644
index 0000000..77e1282
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png
new file mode 100644
index 0000000..ca6d7f8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.png
new file mode 100644
index 0000000..3500812
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.png
new file mode 100644
index 0000000..37fef68
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.png
new file mode 100644
index 0000000..eb04b1a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.png
new file mode 100644
index 0000000..a9f44ba
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.png
new file mode 100644
index 0000000..e53928e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
new file mode 100644
index 0000000..03e0b0e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png
new file mode 100644
index 0000000..077db3e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png
new file mode 100644
index 0000000..9ef7182
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_convolutional_barycenter_001.png b/docs/source/auto_examples/images/sphx_glr_plot_convolutional_barycenter_001.png
new file mode 100644
index 0000000..14a72a3
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_convolutional_barycenter_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png
new file mode 100644
index 0000000..4e0df9f
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png
new file mode 100644
index 0000000..d0e36e8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_010.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png b/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png
new file mode 100644
index 0000000..6d7e630
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_fgw_011.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_free_support_barycenter_001.png b/docs/source/auto_examples/images/sphx_glr_plot_free_support_barycenter_001.png
new file mode 100644
index 0000000..d7bc78a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_free_support_barycenter_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png
new file mode 100644
index 0000000..2e9b38e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png
new file mode 100644
index 0000000..343fd78
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png
new file mode 100644
index 0000000..93e1def
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png
new file mode 100644
index 0000000..0665c9b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
new file mode 100644
index 0000000..4421bc7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
new file mode 100644
index 0000000..bf7c076
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png
new file mode 100644
index 0000000..afca192
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png
new file mode 100644
index 0000000..daa2a8d
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png
new file mode 100644
index 0000000..71ef350
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png
new file mode 100644
index 0000000..3c33d5b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png
new file mode 100644
index 0000000..7de991a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png
new file mode 100644
index 0000000..aac929b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png
new file mode 100644
index 0000000..5b8101b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png
new file mode 100644
index 0000000..114871a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png
new file mode 100644
index 0000000..78ac59b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png
new file mode 100644
index 0000000..7385dcc
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_001.png
new file mode 100644
index 0000000..88796df
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_002.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_002.png
new file mode 100644
index 0000000..22b5d0c
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_004.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_004.png
new file mode 100644
index 0000000..ff10b72
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_linear_mapping_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png
new file mode 100644
index 0000000..16a228a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png
new file mode 100644
index 0000000..02fe3d6
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png
new file mode 100644
index 0000000..d77e68a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png
new file mode 100644
index 0000000..1199903
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png
new file mode 100644
index 0000000..1c73e43
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png
new file mode 100644
index 0000000..9b5ae7a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png
new file mode 100644
index 0000000..26ab6f6
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png
new file mode 100644
index 0000000..2b3bf0e
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_stochastic_004.png b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_004.png
new file mode 100644
index 0000000..8aada91
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_stochastic_005.png b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_005.png
new file mode 100644
index 0000000..42e5007
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_stochastic_006.png b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_006.png
new file mode 100644
index 0000000..335ea95
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_stochastic_007.png b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_007.png
new file mode 100644
index 0000000..cda643b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_007.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_stochastic_008.png b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_008.png
new file mode 100644
index 0000000..335ea95
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_stochastic_008.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.png
new file mode 100644
index 0000000..4679eb6
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
new file mode 100644
index 0000000..4679eb6
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
new file mode 100644
index 0000000..ae33588
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
new file mode 100644
index 0000000..cdf1208
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png
new file mode 100644
index 0000000..1d048f2
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png
new file mode 100644
index 0000000..999f175
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
new file mode 100644
index 0000000..2316fcc
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
new file mode 100644
index 0000000..c68e95f
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png
new file mode 100644
index 0000000..9c3244e
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_lp_vs_entropic_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_lp_vs_entropic_thumb.png
new file mode 100644
index 0000000..c68e95f
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_lp_vs_entropic_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
new file mode 100644
index 0000000..4531351
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_convolutional_barycenter_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_convolutional_barycenter_thumb.png
new file mode 100644
index 0000000..af8aad2
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_convolutional_barycenter_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png
new file mode 100644
index 0000000..609339d
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.png
new file mode 100644
index 0000000..0861d4d
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png
new file mode 100644
index 0000000..df25b39
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
new file mode 100644
index 0000000..6f250a4
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
new file mode 100644
index 0000000..cbc8e0f
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png
new file mode 100644
index 0000000..ec78552
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png
new file mode 100644
index 0000000..4d90437
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png
new file mode 100644
index 0000000..4f8f72f
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_linear_mapping_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_linear_mapping_thumb.png
new file mode 100644
index 0000000..277950e
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_linear_mapping_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png
new file mode 100644
index 0000000..61a5137
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png
new file mode 100644
index 0000000..bd7c939
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png
new file mode 100644
index 0000000..b683392
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_stochastic_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_stochastic_thumb.png
new file mode 100644
index 0000000..609339d
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_stochastic_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst
new file mode 100644
index 0000000..fe6702d
--- /dev/null
+++ b/docs/source/auto_examples/index.rst
@@ -0,0 +1,535 @@
+:orphan:
+
+POT Examples
+============
+
+This is a gallery of all the POT example files.
+
+
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of EMD and Sinkhorn transport plans and their visualiz...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_OT_1D.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_OT_1D
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of Unbalanced Optimal transport using a Kullback-Leibl...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_UOT_1D_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_UOT_1D.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_UOT_1D
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustrates the use of the generic solver for regularized OT with user-designed regularization ...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_optim_OTreg.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_optim_OTreg
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D Wasserstein barycenters if discributions that are weighted sum of diracs.">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_free_support_barycenter_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_free_support_barycenter.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_free_support_barycenter
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of EMD, Sinkhorn and smooth OT plans and their visuali...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_1D_smooth_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_OT_1D_smooth.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_OT_1D_smooth
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wassertsein distance computation in POT....">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_gromov.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_gromov
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="Shows how to compute multiple EMD and Sinkhorn with two differnt ground metrics and plot their ...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_compute_emd.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_compute_emd
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to illustrate how the Convolutional Wasserstein Barycenter function of...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_convolutional_barycenter_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_convolutional_barycenter.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_convolutional_barycenter
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip=" ">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_linear_mapping_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_linear_mapping.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_linear_mapping
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrate the use of WDA as proposed in [11].">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_WDA.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_WDA
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D optimal transport between discributions that are weighted sum of diracs. The...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_OT_2D_samples.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_OT_2D_samples
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the stochatic optimization algorithms for descrete ...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_stochastic_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_stochastic.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_stochastic
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example presents a way of transferring colors between two images with Optimal Transport as...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_color_images.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_color_images
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of regularized Wassersyein Barycenter as proposed in [...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_barycenter_1D.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_barycenter_1D
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="OT for domain adaptation with image color adaptation [6] with mapping estimation [8].">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_mapping_colors_images.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_mapping_colors_images
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of regularized Wassersyein Barycenter as proposed in [...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_UOT_barycenter_1D_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_UOT_barycenter_1D.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_UOT_barycenter_1D
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example presents how to use MappingTransport to estimate at the same time both the couplin...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_mapping.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_mapping
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a semi supervised domain adaptation in a 2D setting. It explicits the p...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_semi_supervised.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_semi_supervised
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of FGW for 1D measures[18].">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_fgw_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_fgw.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_fgw
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a domain adaptation in a 2D setting and the 4 OTDA approaches currently...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_classes.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_classes
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a domain adaptation in a 2D setting. It explicits the problem of domain...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_d2.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_d2
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="2D OT on empirical distributio with different gound metric.">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_OT_L1_vs_L2.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_OT_L1_vs_L2
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of regularized Wasserstein Barycenter as proposed in [...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_lp_vs_entropic_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_barycenter_lp_vs_entropic.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_barycenter_lp_vs_entropic
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation barycenter of labeled graphs using FGW">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_fgw_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_barycenter_fgw.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_barycenter_fgw
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wasserstein distance computation in POT....">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_gromov_barycenter.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_gromov_barycenter
+.. raw:: html
+
+ <div style='clear:both'></div>
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download all examples in Python source code: auto_examples_python.zip <//home/rflamary/PYTHON/POT/docs/source/auto_examples/auto_examples_python.zip>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download all examples in Jupyter notebooks: auto_examples_jupyter.zip <//home/rflamary/PYTHON/POT/docs/source/auto_examples/auto_examples_jupyter.zip>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OT_1D.ipynb b/docs/source/auto_examples/plot_OT_1D.ipynb
new file mode 100644
index 0000000..bd0439e
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_1D.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D optimal transport\n\n\nThis example illustrates the computation of EMD and Sinkhorn transport plans\nand their visualization.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot distributions and loss matrix\n----------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve EMD\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve Sinkhorn\n--------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Sinkhorn\n\nlambd = 1e-3\nGs = ot.sinkhorn(a, b, M, lambd, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_1D.py b/docs/source/auto_examples/plot_OT_1D.py
new file mode 100644
index 0000000..f33e2a4
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_1D.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+"""
+====================
+1D optimal transport
+====================
+
+This example illustrates the computation of EMD and Sinkhorn transport plans
+and their visualization.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve EMD
+# ---------
+
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve Sinkhorn
+# --------------
+
+
+#%% Sinkhorn
+
+lambd = 1e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_OT_1D.rst b/docs/source/auto_examples/plot_OT_1D.rst
new file mode 100644
index 0000000..b97d67c
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_1D.rst
@@ -0,0 +1,199 @@
+
+
+.. _sphx_glr_auto_examples_plot_OT_1D.py:
+
+
+====================
+1D optimal transport
+====================
+
+This example illustrates the computation of EMD and Sinkhorn transport plans
+and their visualization.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+ from ot.datasets import make_1D_gauss as gauss
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+ b = gauss(n, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+
+Plot distributions and loss matrix
+----------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.plot(x, a, 'b', label='Source distribution')
+ pl.plot(x, b, 'r', label='Target distribution')
+ pl.legend()
+
+ #%% plot distributions and loss matrix
+
+ pl.figure(2, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_002.png
+ :scale: 47
+
+
+
+
+Solve EMD
+---------
+
+
+
+.. code-block:: python
+
+
+
+ #%% EMD
+
+ G0 = ot.emd(a, b, M)
+
+ pl.figure(3, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_005.png
+ :align: center
+
+
+
+
+Solve Sinkhorn
+--------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% Sinkhorn
+
+ lambd = 1e-3
+ Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_007.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Err
+ -------------------
+ 0|8.187970e-02|
+ 10|3.460174e-02|
+ 20|6.633335e-03|
+ 30|9.797798e-04|
+ 40|1.389606e-04|
+ 50|1.959016e-05|
+ 60|2.759079e-06|
+ 70|3.885166e-07|
+ 80|5.470605e-08|
+ 90|7.702918e-09|
+ 100|1.084609e-09|
+ 110|1.527180e-10|
+
+
+**Total running time of the script:** ( 0 minutes 0.561 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_OT_1D.py <plot_OT_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_OT_1D.ipynb <plot_OT_1D.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OT_1D_smooth.ipynb b/docs/source/auto_examples/plot_OT_1D_smooth.ipynb
new file mode 100644
index 0000000..d523f6a
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_1D_smooth.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D smooth optimal transport\n\n\nThis example illustrates the computation of EMD, Sinkhorn and smooth OT plans\nand their visualization.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot distributions and loss matrix\n----------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve EMD\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve Sinkhorn\n--------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Sinkhorn\n\nlambd = 2e-3\nGs = ot.sinkhorn(a, b, M, lambd, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')\n\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve Smooth OT\n--------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Smooth OT with KL regularization\n\nlambd = 2e-3\nGsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')\n\npl.figure(5, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')\n\npl.show()\n\n\n#%% Smooth OT with KL regularization\n\nlambd = 1e-1\nGsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')\n\npl.figure(6, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_1D_smooth.py b/docs/source/auto_examples/plot_OT_1D_smooth.py
new file mode 100644
index 0000000..b690751
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_1D_smooth.py
@@ -0,0 +1,110 @@
+# -*- coding: utf-8 -*-
+"""
+===========================
+1D smooth optimal transport
+===========================
+
+This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
+and their visualization.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+#%% plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve EMD
+# ---------
+
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve Sinkhorn
+# --------------
+
+
+#%% Sinkhorn
+
+lambd = 2e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+pl.show()
+
+##############################################################################
+# Solve Smooth OT
+# --------------
+
+
+#%% Smooth OT with KL regularization
+
+lambd = 2e-3
+Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
+
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
+
+pl.show()
+
+
+#%% Smooth OT with KL regularization
+
+lambd = 1e-1
+Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
+
+pl.figure(6, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_OT_1D_smooth.rst b/docs/source/auto_examples/plot_OT_1D_smooth.rst
new file mode 100644
index 0000000..5a0ebd3
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_1D_smooth.rst
@@ -0,0 +1,242 @@
+
+
+.. _sphx_glr_auto_examples_plot_OT_1D_smooth.py:
+
+
+===========================
+1D smooth optimal transport
+===========================
+
+This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
+and their visualization.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+ from ot.datasets import make_1D_gauss as gauss
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+ b = gauss(n, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+
+Plot distributions and loss matrix
+----------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.plot(x, a, 'b', label='Source distribution')
+ pl.plot(x, b, 'r', label='Target distribution')
+ pl.legend()
+
+ #%% plot distributions and loss matrix
+
+ pl.figure(2, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_smooth_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_smooth_002.png
+ :scale: 47
+
+
+
+
+Solve EMD
+---------
+
+
+
+.. code-block:: python
+
+
+
+ #%% EMD
+
+ G0 = ot.emd(a, b, M)
+
+ pl.figure(3, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_smooth_005.png
+ :align: center
+
+
+
+
+Solve Sinkhorn
+--------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% Sinkhorn
+
+ lambd = 2e-3
+ Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
+
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_smooth_007.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Err
+ -------------------
+ 0|7.958844e-02|
+ 10|5.921715e-03|
+ 20|1.238266e-04|
+ 30|2.469780e-06|
+ 40|4.919966e-08|
+ 50|9.800197e-10|
+
+
+Solve Smooth OT
+--------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% Smooth OT with KL regularization
+
+ lambd = 2e-3
+ Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')
+
+ pl.figure(5, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')
+
+ pl.show()
+
+
+ #%% Smooth OT with KL regularization
+
+ lambd = 1e-1
+ Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
+
+ pl.figure(6, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')
+
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_smooth_009.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_smooth_010.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 1.053 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_OT_1D_smooth.py <plot_OT_1D_smooth.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_OT_1D_smooth.ipynb <plot_OT_1D_smooth.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
new file mode 100644
index 0000000..dad138b
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 2D Optimal transport between empirical distributions\n\n\nIllustration of 2D optimal transport between discributions that are weighted\nsum of diracs. The OT matrix is plotted with the samples.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n# Kilian Fatras <kilian.fatras@irisa.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters and data generation\n\nn = 50 # nb samples\n\nmu_s = np.array([0, 0])\ncov_s = np.array([[1, 0], [0, 1]])\n\nmu_t = np.array([4, 4])\ncov_t = np.array([[1, -.8], [-.8, 1]])\n\nxs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)\nxt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)\n\na, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples\n\n# loss matrix\nM = ot.dist(xs, xt)\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot samples\n\npl.figure(1)\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('Source and target distributions')\n\npl.figure(2)\npl.imshow(M, interpolation='nearest')\npl.title('Cost matrix M')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute EMD\n-----------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3)\npl.imshow(G0, interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute Sinkhorn\n----------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGs = ot.sinkhorn(a, b, M, lambd)\n\npl.figure(5)\npl.imshow(Gs, interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')\n\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Emprirical Sinkhorn\n----------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGes = ot.bregman.empirical_sinkhorn(xs, xt, lambd)\n\npl.figure(7)\npl.imshow(Ges, interpolation='nearest')\npl.title('OT matrix empirical sinkhorn')\n\npl.figure(8)\not.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn from samples')\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py
new file mode 100644
index 0000000..63126ba
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_2D_samples.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D Optimal transport between empirical distributions
+====================================================
+
+Illustration of 2D optimal transport between discributions that are weighted
+sum of diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Kilian Fatras <kilian.fatras@irisa.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters and data generation
+
+n = 50 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+##############################################################################
+# Compute EMD
+# -----------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix G0')
+
+pl.figure(4)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix with samples')
+
+
+##############################################################################
+# Compute Sinkhorn
+# ----------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Gs = ot.sinkhorn(a, b, M, lambd)
+
+pl.figure(5)
+pl.imshow(Gs, interpolation='nearest')
+pl.title('OT matrix sinkhorn')
+
+pl.figure(6)
+ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn with samples')
+
+pl.show()
+
+
+##############################################################################
+# Emprirical Sinkhorn
+# ----------------
+
+#%% sinkhorn
+
+# reg term
+lambd = 1e-3
+
+Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+pl.figure(7)
+pl.imshow(Ges, interpolation='nearest')
+pl.title('OT matrix empirical sinkhorn')
+
+pl.figure(8)
+ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('OT matrix Sinkhorn from samples')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst
new file mode 100644
index 0000000..1f1d713
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_2D_samples.rst
@@ -0,0 +1,273 @@
+
+
+.. _sphx_glr_auto_examples_plot_OT_2D_samples.py:
+
+
+====================================================
+2D Optimal transport between empirical distributions
+====================================================
+
+Illustration of 2D optimal transport between discributions that are weighted
+sum of diracs. The OT matrix is plotted with the samples.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ # Kilian Fatras <kilian.fatras@irisa.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters and data generation
+
+ n = 50 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([4, 4])
+ cov_t = np.array([[1, -.8], [-.8, 1]])
+
+ xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+ xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
+
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot samples
+
+ pl.figure(1)
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('Source and target distributions')
+
+ pl.figure(2)
+ pl.imshow(M, interpolation='nearest')
+ pl.title('Cost matrix M')
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
+ :scale: 47
+
+
+
+
+Compute EMD
+-----------
+
+
+
+.. code-block:: python
+
+
+ #%% EMD
+
+ G0 = ot.emd(a, b, M)
+
+ pl.figure(3)
+ pl.imshow(G0, interpolation='nearest')
+ pl.title('OT matrix G0')
+
+ pl.figure(4)
+ ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('OT matrix with samples')
+
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
+ :scale: 47
+
+
+
+
+Compute Sinkhorn
+----------------
+
+
+
+.. code-block:: python
+
+
+ #%% sinkhorn
+
+ # reg term
+ lambd = 1e-3
+
+ Gs = ot.sinkhorn(a, b, M, lambd)
+
+ pl.figure(5)
+ pl.imshow(Gs, interpolation='nearest')
+ pl.title('OT matrix sinkhorn')
+
+ pl.figure(6)
+ ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('OT matrix Sinkhorn with samples')
+
+ pl.show()
+
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
+ :scale: 47
+
+
+
+
+Emprirical Sinkhorn
+----------------
+
+
+
+.. code-block:: python
+
+
+ #%% sinkhorn
+
+ # reg term
+ lambd = 1e-3
+
+ Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
+
+ pl.figure(7)
+ pl.imshow(Ges, interpolation='nearest')
+ pl.title('OT matrix empirical sinkhorn')
+
+ pl.figure(8)
+ ot.plot.plot2D_samples_mat(xs, xt, Ges, color=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('OT matrix Sinkhorn from samples')
+
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_013.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_014.png
+ :scale: 47
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Warning: numerical errors at iteration 0
+
+
+**Total running time of the script:** ( 0 minutes 2.616 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_OT_2D_samples.py <plot_OT_2D_samples.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_OT_2D_samples.ipynb <plot_OT_2D_samples.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb b/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
new file mode 100644
index 0000000..125d720
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 2D Optimal transport for different metrics\n\n\n2D OT on empirical distributio with different gound metric.\n\nStole the figure idea from Fig. 1 and 2 in\nhttps://arxiv.org/pdf/1706.07650.pdf\n\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Dataset 1 : uniform sampling\n----------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n = 20 # nb samples\nxs = np.zeros((n, 2))\nxs[:, 0] = np.arange(n) + 1\nxs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...\n\nxt = np.zeros((n, 2))\nxt[:, 1] = np.arange(n) + 1\n\na, b = ot.unif(n), ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM1 = ot.dist(xs, xt, metric='euclidean')\nM1 /= M1.max()\n\n# loss matrix\nM2 = ot.dist(xs, xt, metric='sqeuclidean')\nM2 /= M2.max()\n\n# loss matrix\nMp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))\nMp /= Mp.max()\n\n# Data\npl.figure(1, figsize=(7, 3))\npl.clf()\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\npl.title('Source and target distributions')\n\n\n# Cost matrices\npl.figure(2, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\npl.imshow(M1, interpolation='nearest')\npl.title('Euclidean cost')\n\npl.subplot(1, 3, 2)\npl.imshow(M2, interpolation='nearest')\npl.title('Squared Euclidean cost')\n\npl.subplot(1, 3, 3)\npl.imshow(Mp, interpolation='nearest')\npl.title('Sqrt Euclidean cost')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Dataset 1 : Plot OT Matrices\n----------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% EMD\nG1 = ot.emd(a, b, M1)\nG2 = ot.emd(a, b, M2)\nGp = ot.emd(a, b, Mp)\n\n# OT matrices\npl.figure(3, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\not.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT Euclidean')\n\npl.subplot(1, 3, 2)\not.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT squared Euclidean')\n\npl.subplot(1, 3, 3)\not.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT sqrt Euclidean')\npl.tight_layout()\n\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Dataset 2 : Partial circle\n--------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n = 50 # nb samples\nxtot = np.zeros((n + 1, 2))\nxtot[:, 0] = np.cos(\n (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)\nxtot[:, 1] = np.sin(\n (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)\n\nxs = xtot[:n, :]\nxt = xtot[1:, :]\n\na, b = ot.unif(n), ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM1 = ot.dist(xs, xt, metric='euclidean')\nM1 /= M1.max()\n\n# loss matrix\nM2 = ot.dist(xs, xt, metric='sqeuclidean')\nM2 /= M2.max()\n\n# loss matrix\nMp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))\nMp /= Mp.max()\n\n\n# Data\npl.figure(4, figsize=(7, 3))\npl.clf()\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\npl.title('Source and traget distributions')\n\n\n# Cost matrices\npl.figure(5, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\npl.imshow(M1, interpolation='nearest')\npl.title('Euclidean cost')\n\npl.subplot(1, 3, 2)\npl.imshow(M2, interpolation='nearest')\npl.title('Squared Euclidean cost')\n\npl.subplot(1, 3, 3)\npl.imshow(Mp, interpolation='nearest')\npl.title('Sqrt Euclidean cost')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Dataset 2 : Plot OT Matrices\n-----------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% EMD\nG1 = ot.emd(a, b, M1)\nG2 = ot.emd(a, b, M2)\nGp = ot.emd(a, b, Mp)\n\n# OT matrices\npl.figure(6, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\not.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT Euclidean')\n\npl.subplot(1, 3, 2)\not.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT squared Euclidean')\n\npl.subplot(1, 3, 3)\not.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT sqrt Euclidean')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.py b/docs/source/auto_examples/plot_OT_L1_vs_L2.py
new file mode 100644
index 0000000..37b429f
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+"""
+==========================================
+2D Optimal transport for different metrics
+==========================================
+
+2D OT on empirical distributio with different gound metric.
+
+Stole the figure idea from Fig. 1 and 2 in
+https://arxiv.org/pdf/1706.07650.pdf
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Dataset 1 : uniform sampling
+# ----------------------------
+
+n = 20 # nb samples
+xs = np.zeros((n, 2))
+xs[:, 0] = np.arange(n) + 1
+xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+xt = np.zeros((n, 2))
+xt[:, 1] = np.arange(n) + 1
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M1 = ot.dist(xs, xt, metric='euclidean')
+M1 /= M1.max()
+
+# loss matrix
+M2 = ot.dist(xs, xt, metric='sqeuclidean')
+M2 /= M2.max()
+
+# loss matrix
+Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp /= Mp.max()
+
+# Data
+pl.figure(1, figsize=(7, 3))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+pl.title('Source and target distributions')
+
+
+# Cost matrices
+pl.figure(2, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+pl.imshow(M1, interpolation='nearest')
+pl.title('Euclidean cost')
+
+pl.subplot(1, 3, 2)
+pl.imshow(M2, interpolation='nearest')
+pl.title('Squared Euclidean cost')
+
+pl.subplot(1, 3, 3)
+pl.imshow(Mp, interpolation='nearest')
+pl.title('Sqrt Euclidean cost')
+pl.tight_layout()
+
+##############################################################################
+# Dataset 1 : Plot OT Matrices
+# ----------------------------
+
+
+#%% EMD
+G1 = ot.emd(a, b, M1)
+G2 = ot.emd(a, b, M2)
+Gp = ot.emd(a, b, Mp)
+
+# OT matrices
+pl.figure(3, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT Euclidean')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT squared Euclidean')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT sqrt Euclidean')
+pl.tight_layout()
+
+pl.show()
+
+
+##############################################################################
+# Dataset 2 : Partial circle
+# --------------------------
+
+n = 50 # nb samples
+xtot = np.zeros((n + 1, 2))
+xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+xs = xtot[:n, :]
+xt = xtot[1:, :]
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M1 = ot.dist(xs, xt, metric='euclidean')
+M1 /= M1.max()
+
+# loss matrix
+M2 = ot.dist(xs, xt, metric='sqeuclidean')
+M2 /= M2.max()
+
+# loss matrix
+Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp /= Mp.max()
+
+
+# Data
+pl.figure(4, figsize=(7, 3))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+pl.title('Source and traget distributions')
+
+
+# Cost matrices
+pl.figure(5, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+pl.imshow(M1, interpolation='nearest')
+pl.title('Euclidean cost')
+
+pl.subplot(1, 3, 2)
+pl.imshow(M2, interpolation='nearest')
+pl.title('Squared Euclidean cost')
+
+pl.subplot(1, 3, 3)
+pl.imshow(Mp, interpolation='nearest')
+pl.title('Sqrt Euclidean cost')
+pl.tight_layout()
+
+##############################################################################
+# Dataset 2 : Plot OT Matrices
+# -----------------------------
+
+
+#%% EMD
+G1 = ot.emd(a, b, M1)
+G2 = ot.emd(a, b, M2)
+Gp = ot.emd(a, b, Mp)
+
+# OT matrices
+pl.figure(6, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT Euclidean')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT squared Euclidean')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT sqrt Euclidean')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.rst b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
new file mode 100644
index 0000000..5db4b55
--- /dev/null
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
@@ -0,0 +1,318 @@
+
+
+.. _sphx_glr_auto_examples_plot_OT_L1_vs_L2.py:
+
+
+==========================================
+2D Optimal transport for different metrics
+==========================================
+
+2D OT on empirical distributio with different gound metric.
+
+Stole the figure idea from Fig. 1 and 2 in
+https://arxiv.org/pdf/1706.07650.pdf
+
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+
+
+
+
+
+
+
+Dataset 1 : uniform sampling
+----------------------------
+
+
+
+.. code-block:: python
+
+
+ n = 20 # nb samples
+ xs = np.zeros((n, 2))
+ xs[:, 0] = np.arange(n) + 1
+ xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+ xt = np.zeros((n, 2))
+ xt[:, 1] = np.arange(n) + 1
+
+ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+ # loss matrix
+ M1 = ot.dist(xs, xt, metric='euclidean')
+ M1 /= M1.max()
+
+ # loss matrix
+ M2 = ot.dist(xs, xt, metric='sqeuclidean')
+ M2 /= M2.max()
+
+ # loss matrix
+ Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+ Mp /= Mp.max()
+
+ # Data
+ pl.figure(1, figsize=(7, 3))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ pl.title('Source and target distributions')
+
+
+ # Cost matrices
+ pl.figure(2, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
+ pl.title('Squared Euclidean cost')
+
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
+ :scale: 47
+
+
+
+
+Dataset 1 : Plot OT Matrices
+----------------------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% EMD
+ G1 = ot.emd(a, b, M1)
+ G2 = ot.emd(a, b, M2)
+ Gp = ot.emd(a, b, Mp)
+
+ # OT matrices
+ pl.figure(3, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT Euclidean')
+
+ pl.subplot(1, 3, 2)
+ ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
+
+ pl.subplot(1, 3, 3)
+ ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
+ :align: center
+
+
+
+
+Dataset 2 : Partial circle
+--------------------------
+
+
+
+.. code-block:: python
+
+
+ n = 50 # nb samples
+ xtot = np.zeros((n + 1, 2))
+ xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+ xs = xtot[:n, :]
+ xt = xtot[1:, :]
+
+ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+ # loss matrix
+ M1 = ot.dist(xs, xt, metric='euclidean')
+ M1 /= M1.max()
+
+ # loss matrix
+ M2 = ot.dist(xs, xt, metric='sqeuclidean')
+ M2 /= M2.max()
+
+ # loss matrix
+ Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+ Mp /= Mp.max()
+
+
+ # Data
+ pl.figure(4, figsize=(7, 3))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ pl.title('Source and traget distributions')
+
+
+ # Cost matrices
+ pl.figure(5, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
+ pl.title('Squared Euclidean cost')
+
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png
+ :scale: 47
+
+
+
+
+Dataset 2 : Plot OT Matrices
+-----------------------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% EMD
+ G1 = ot.emd(a, b, M1)
+ G2 = ot.emd(a, b, M2)
+ Gp = ot.emd(a, b, Mp)
+
+ # OT matrices
+ pl.figure(6, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT Euclidean')
+
+ pl.subplot(1, 3, 2)
+ ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
+
+ pl.subplot(1, 3, 3)
+ ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.958 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_OT_L1_vs_L2.py <plot_OT_L1_vs_L2.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_OT_L1_vs_L2.ipynb <plot_OT_L1_vs_L2.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_UOT_1D.ipynb b/docs/source/auto_examples/plot_UOT_1D.ipynb
new file mode 100644
index 0000000..c695306
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_1D.ipynb
@@ -0,0 +1,108 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D Unbalanced optimal transport\n\n\nThis example illustrates the computation of Unbalanced Optimal transport\nusing a Kullback-Leibler relaxation.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Hicham Janati <hicham.janati@inria.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# make distributions unbalanced\nb *= 5.\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot distributions and loss matrix\n----------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n# plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve Unbalanced Sinkhorn\n--------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Sinkhorn\n\nepsilon = 0.1 # entropy parameter\nalpha = 1. # Unbalanced KL relaxation parameter\nGs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_UOT_1D.py b/docs/source/auto_examples/plot_UOT_1D.py
new file mode 100644
index 0000000..2ea8b05
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_1D.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+"""
+===============================
+1D Unbalanced optimal transport
+===============================
+
+This example illustrates the computation of Unbalanced Optimal transport
+using a Kullback-Leibler relaxation.
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+from ot.datasets import make_1D_gauss as gauss
+
+##############################################################################
+# Generate data
+# -------------
+
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
+
+# make distributions unbalanced
+b *= 5.
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
+pl.legend()
+
+# plot distributions and loss matrix
+
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+##############################################################################
+# Solve Unbalanced Sinkhorn
+# --------------
+
+
+# Sinkhorn
+
+epsilon = 0.1 # entropy parameter
+alpha = 1. # Unbalanced KL relaxation parameter
+Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_UOT_1D.rst b/docs/source/auto_examples/plot_UOT_1D.rst
new file mode 100644
index 0000000..8e618b4
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_1D.rst
@@ -0,0 +1,173 @@
+
+
+.. _sphx_glr_auto_examples_plot_UOT_1D.py:
+
+
+===============================
+1D Unbalanced optimal transport
+===============================
+
+This example illustrates the computation of Unbalanced Optimal transport
+using a Kullback-Leibler relaxation.
+
+
+
+.. code-block:: python
+
+
+ # Author: Hicham Janati <hicham.janati@inria.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+ from ot.datasets import make_1D_gauss as gauss
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+ b = gauss(n, m=60, s=10)
+
+ # make distributions unbalanced
+ b *= 5.
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+
+Plot distributions and loss matrix
+----------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.plot(x, a, 'b', label='Source distribution')
+ pl.plot(x, b, 'r', label='Target distribution')
+ pl.legend()
+
+ # plot distributions and loss matrix
+
+ pl.figure(2, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_002.png
+ :scale: 47
+
+
+
+
+Solve Unbalanced Sinkhorn
+--------------
+
+
+
+.. code-block:: python
+
+
+
+ # Sinkhorn
+
+ epsilon = 0.1 # entropy parameter
+ alpha = 1. # Unbalanced KL relaxation parameter
+ Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_UOT_1D_006.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Err
+ -------------------
+ 0|1.838786e+00|
+ 10|1.242379e-01|
+ 20|2.581314e-03|
+ 30|5.674552e-05|
+ 40|1.252959e-06|
+ 50|2.768136e-08|
+ 60|6.116090e-10|
+
+
+**Total running time of the script:** ( 0 minutes 0.259 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_UOT_1D.py <plot_UOT_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_UOT_1D.ipynb <plot_UOT_1D.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb b/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb
new file mode 100644
index 0000000..e59cdc2
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D Wasserstein barycenter demo for Unbalanced distributions\n\n\nThis example illustrates the computation of regularized Wassersyein Barycenter\nas proposed in [10] for Unbalanced inputs.\n\n\n[10] Chizat, L., Peyr\u00e9, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Hicham Janati <hicham.janati@inria.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# make unbalanced dists\na2 *= 3.\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# non weighted barycenter computation\n\nweight = 0.5 # 0<=weight<=1\nweights = np.array([1 - weight, weight])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\nalpha = 1.\n\nbary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycentric interpolation\n-------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# barycenter interpolation\n\nn_weight = 11\nweight_list = np.linspace(0, 1, n_weight)\n\n\nB_l2 = np.zeros((n, n_weight))\n\nB_wass = np.copy(B_l2)\n\nfor i in range(0, n_weight):\n weight = weight_list[i]\n weights = np.array([1 - weight, weight])\n B_l2[:, i] = A.dot(weights)\n B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)\n\n\n# plot interpolation\n\npl.figure(3)\n\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = weight_list\nfor i, z in enumerate(zs):\n ys = B_l2[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel(r'$\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with l2')\npl.tight_layout()\n\npl.figure(4)\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = weight_list\nfor i, z in enumerate(zs):\n ys = B_wass[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel(r'$\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with Wasserstein')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.py b/docs/source/auto_examples/plot_UOT_barycenter_1D.py
new file mode 100644
index 0000000..c8d9d3b
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+"""
+
+# Author: Hicham Janati <hicham.janati@inria.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection
+
+##############################################################################
+# Generate data
+# -------------
+
+# parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# make unbalanced dists
+a2 *= 3.
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+# plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+# non weighted barycenter computation
+
+weight = 0.5 # 0<=weight<=1
+weights = np.array([1 - weight, weight])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+alpha = 1.
+
+bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+##############################################################################
+# Barycentric interpolation
+# -------------------------
+
+# barycenter interpolation
+
+n_weight = 11
+weight_list = np.linspace(0, 1, n_weight)
+
+
+B_l2 = np.zeros((n, n_weight))
+
+B_wass = np.copy(B_l2)
+
+for i in range(0, n_weight):
+ weight = weight_list[i]
+ weights = np.array([1 - weight, weight])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+
+# plot interpolation
+
+pl.figure(3)
+
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
+
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = weight_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel(r'$\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_UOT_barycenter_1D.rst b/docs/source/auto_examples/plot_UOT_barycenter_1D.rst
new file mode 100644
index 0000000..ac17587
--- /dev/null
+++ b/docs/source/auto_examples/plot_UOT_barycenter_1D.rst
@@ -0,0 +1,261 @@
+
+
+.. _sphx_glr_auto_examples_plot_UOT_barycenter_1D.py:
+
+
+===========================================================
+1D Wasserstein barycenter demo for Unbalanced distributions
+===========================================================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [10] for Unbalanced inputs.
+
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Hicham Janati <hicham.janati@inria.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ # necessary for 3d plot even if not used
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ from matplotlib.collections import PolyCollection
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ # parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+ # make unbalanced dists
+ a2 *= 3.
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ # plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+ # non weighted barycenter computation
+
+ weight = 0.5 # 0<=weight<=1
+ weights = np.array([1 - weight, weight])
+
+ # l2bary
+ bary_l2 = A.dot(weights)
+
+ # wasserstein
+ reg = 1e-3
+ alpha = 1.
+
+ bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_003.png
+ :align: center
+
+
+
+
+Barycentric interpolation
+-------------------------
+
+
+
+.. code-block:: python
+
+
+ # barycenter interpolation
+
+ n_weight = 11
+ weight_list = np.linspace(0, 1, n_weight)
+
+
+ B_l2 = np.zeros((n, n_weight))
+
+ B_wass = np.copy(B_l2)
+
+ for i in range(0, n_weight):
+ weight = weight_list[i]
+ weights = np.array([1 - weight, weight])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)
+
+
+ # plot interpolation
+
+ pl.figure(3)
+
+ cmap = pl.cm.get_cmap('viridis')
+ verts = []
+ zs = weight_list
+ for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel(r'$\alpha$')
+ ax.set_ylim3d(0, 1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
+ pl.title('Barycenter interpolation with l2')
+ pl.tight_layout()
+
+ pl.figure(4)
+ cmap = pl.cm.get_cmap('viridis')
+ verts = []
+ zs = weight_list
+ for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel(r'$\alpha$')
+ ax.set_ylim3d(0, 1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
+ pl.title('Barycenter interpolation with Wasserstein')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_UOT_barycenter_1D_006.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.344 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_UOT_barycenter_1D.py <plot_UOT_barycenter_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_UOT_barycenter_1D.ipynb <plot_UOT_barycenter_1D.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_WDA.ipynb b/docs/source/auto_examples/plot_WDA.ipynb
new file mode 100644
index 0000000..1661c53
--- /dev/null
+++ b/docs/source/auto_examples/plot_WDA.ipynb
@@ -0,0 +1,144 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "\n# Wasserstein Discriminant Analysis\n\n\nThis example illustrate the use of WDA as proposed in [11].\n\n\n[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).\nWasserstein Discriminant Analysis.\n\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\n\nfrom ot.dr import wda, fda"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters\n\nn = 1000 # nb samples in source and target datasets\nnz = 0.2\n\n# generate circle dataset\nt = np.random.rand(n) * 2 * np.pi\nys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxs = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nt = np.random.rand(n) * 2 * np.pi\nyt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxt = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nnbnoise = 8\n\nxs = np.hstack((xs, np.random.randn(n, nbnoise)))\nxt = np.hstack((xt, np.random.randn(n, nbnoise)))"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot data\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot samples\npl.figure(1, figsize=(6.4, 3.5))\n\npl.subplot(1, 2, 1)\npl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Discriminant dimensions')\n\npl.subplot(1, 2, 2)\npl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Other dimensions')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Fisher Discriminant Analysis\n------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Compute FDA\np = 2\n\nPfda, projfda = fda(xs, ys, p)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Wasserstein Discriminant Analysis\n-----------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Compute WDA\np = 2\nreg = 1e0\nk = 10\nmaxiter = 100\n\nPwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot 2D projections\n-------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot samples\n\nxsp = projfda(xs)\nxtp = projfda(xt)\n\nxspw = projwda(xs)\nxtpw = projwda(xt)\n\npl.figure(2)\n\npl.subplot(2, 2, 1)\npl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples FDA')\n\npl.subplot(2, 2, 2)\npl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples FDA')\n\npl.subplot(2, 2, 3)\npl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples WDA')\n\npl.subplot(2, 2, 4)\npl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples WDA')\npl.tight_layout()\n\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_WDA.py b/docs/source/auto_examples/plot_WDA.py
new file mode 100644
index 0000000..93cc237
--- /dev/null
+++ b/docs/source/auto_examples/plot_WDA.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Wasserstein Discriminant Analysis
+=================================
+
+This example illustrate the use of WDA as proposed in [11].
+
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+
+from ot.dr import wda, fda
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 1000 # nb samples in source and target datasets
+nz = 0.2
+
+# generate circle dataset
+t = np.random.rand(n) * 2 * np.pi
+ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+t = np.random.rand(n) * 2 * np.pi
+yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+nbnoise = 8
+
+xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot samples
+pl.figure(1, figsize=(6.4, 3.5))
+
+pl.subplot(1, 2, 1)
+pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
+pl.legend(loc=0)
+pl.title('Discriminant dimensions')
+
+pl.subplot(1, 2, 2)
+pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+pl.legend(loc=0)
+pl.title('Other dimensions')
+pl.tight_layout()
+
+##############################################################################
+# Compute Fisher Discriminant Analysis
+# ------------------------------------
+
+#%% Compute FDA
+p = 2
+
+Pfda, projfda = fda(xs, ys, p)
+
+##############################################################################
+# Compute Wasserstein Discriminant Analysis
+# -----------------------------------------
+
+#%% Compute WDA
+p = 2
+reg = 1e0
+k = 10
+maxiter = 100
+
+Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+
+##############################################################################
+# Plot 2D projections
+# -------------------
+
+#%% plot samples
+
+xsp = projfda(xs)
+xtp = projfda(xt)
+
+xspw = projwda(xs)
+xtpw = projwda(xt)
+
+pl.figure(2)
+
+pl.subplot(2, 2, 1)
+pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples FDA')
+
+pl.subplot(2, 2, 2)
+pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples FDA')
+
+pl.subplot(2, 2, 3)
+pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected training samples WDA')
+
+pl.subplot(2, 2, 4)
+pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples WDA')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst
new file mode 100644
index 0000000..2d83123
--- /dev/null
+++ b/docs/source/auto_examples/plot_WDA.rst
@@ -0,0 +1,244 @@
+
+
+.. _sphx_glr_auto_examples_plot_WDA.py:
+
+
+=================================
+Wasserstein Discriminant Analysis
+=================================
+
+This example illustrate the use of WDA as proposed in [11].
+
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+
+ from ot.dr import wda, fda
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ n = 1000 # nb samples in source and target datasets
+ nz = 0.2
+
+ # generate circle dataset
+ t = np.random.rand(n) * 2 * np.pi
+ ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+ xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+ t = np.random.rand(n) * 2 * np.pi
+ yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+ xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+ nbnoise = 8
+
+ xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+ xt = np.hstack((xt, np.random.randn(n, nbnoise)))
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot samples
+ pl.figure(1, figsize=(6.4, 3.5))
+
+ pl.subplot(1, 2, 1)
+ pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
+ pl.legend(loc=0)
+ pl.title('Discriminant dimensions')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+ pl.legend(loc=0)
+ pl.title('Other dimensions')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
+ :align: center
+
+
+
+
+Compute Fisher Discriminant Analysis
+------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Compute FDA
+ p = 2
+
+ Pfda, projfda = fda(xs, ys, p)
+
+
+
+
+
+
+
+Compute Wasserstein Discriminant Analysis
+-----------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Compute WDA
+ p = 2
+ reg = 1e0
+ k = 10
+ maxiter = 100
+
+ Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Compiling cost function...
+ Computing gradient of cost function...
+ iter cost val grad. norm
+ 1 +9.0167295050534191e-01 2.28422652e-01
+ 2 +4.8324990550878105e-01 4.89362707e-01
+ 3 +3.4613154515357075e-01 2.84117562e-01
+ 4 +2.5277108387195002e-01 1.24888750e-01
+ 5 +2.4113858393736629e-01 8.07491482e-02
+ 6 +2.3642108593032782e-01 1.67612140e-02
+ 7 +2.3625721372202199e-01 7.68640008e-03
+ 8 +2.3625461994913738e-01 7.42200784e-03
+ 9 +2.3624493441436939e-01 6.43534105e-03
+ 10 +2.3621901383686217e-01 2.17960585e-03
+ 11 +2.3621854258326572e-01 2.03306749e-03
+ 12 +2.3621696458678049e-01 1.37118721e-03
+ 13 +2.3621569489873540e-01 2.76368907e-04
+ 14 +2.3621565599232983e-01 1.41898134e-04
+ 15 +2.3621564465487518e-01 5.96602069e-05
+ 16 +2.3621564232556647e-01 1.08709521e-05
+ 17 +2.3621564230277003e-01 9.17855656e-06
+ 18 +2.3621564224857586e-01 1.73728345e-06
+ 19 +2.3621564224748123e-01 1.17770019e-06
+ 20 +2.3621564224658587e-01 2.16179383e-07
+ Terminated - min grad norm reached after 20 iterations, 9.20 seconds.
+
+
+Plot 2D projections
+-------------------
+
+
+
+.. code-block:: python
+
+
+ #%% plot samples
+
+ xsp = projfda(xs)
+ xtp = projfda(xt)
+
+ xspw = projwda(xs)
+ xtpw = projwda(xt)
+
+ pl.figure(2)
+
+ pl.subplot(2, 2, 1)
+ pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected training samples FDA')
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected test samples FDA')
+
+ pl.subplot(2, 2, 3)
+ pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected training samples WDA')
+
+ pl.subplot(2, 2, 4)
+ pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected test samples WDA')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_WDA_003.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 16.182 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_WDA.py <plot_WDA.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_WDA.ipynb <plot_WDA.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_barycenter_1D.ipynb b/docs/source/auto_examples/plot_barycenter_1D.ipynb
new file mode 100644
index 0000000..fc60e1f
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_1D.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D Wasserstein barycenter demo\n\n\nThis example illustrates the computation of regularized Wassersyein Barycenter\nas proposed in [3].\n\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% barycenter computation\n\nalpha = 0.2 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycentric interpolation\n-------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% barycenter interpolation\n\nn_alpha = 11\nalpha_list = np.linspace(0, 1, n_alpha)\n\n\nB_l2 = np.zeros((n, n_alpha))\n\nB_wass = np.copy(B_l2)\n\nfor i in range(0, n_alpha):\n alpha = alpha_list[i]\n weights = np.array([1 - alpha, alpha])\n B_l2[:, i] = A.dot(weights)\n B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)\n\n#%% plot interpolation\n\npl.figure(3)\n\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = alpha_list\nfor i, z in enumerate(zs):\n ys = B_l2[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with l2')\npl.tight_layout()\n\npl.figure(4)\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = alpha_list\nfor i, z in enumerate(zs):\n ys = B_wass[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with Wasserstein')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_barycenter_1D.py b/docs/source/auto_examples/plot_barycenter_1D.py
new file mode 100644
index 0000000..6864301
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_1D.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+1D Wasserstein barycenter demo
+==============================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [3].
+
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% barycenter computation
+
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+##############################################################################
+# Barycentric interpolation
+# -------------------------
+
+#%% barycenter interpolation
+
+n_alpha = 11
+alpha_list = np.linspace(0, 1, n_alpha)
+
+
+B_l2 = np.zeros((n, n_alpha))
+
+B_wass = np.copy(B_l2)
+
+for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
+
+#%% plot interpolation
+
+pl.figure(3)
+
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel('$\\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
+
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
+verts = []
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ax = pl.gcf().gca(projection='3d')
+
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
+poly.set_alpha(0.7)
+ax.add_collection3d(poly, zs=zs, zdir='y')
+ax.set_xlabel('x')
+ax.set_xlim3d(0, n)
+ax.set_ylabel('$\\alpha$')
+ax.set_ylim3d(0, 1)
+ax.set_zlabel('')
+ax.set_zlim3d(0, B_l2.max() * 1.01)
+pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_barycenter_1D.rst b/docs/source/auto_examples/plot_barycenter_1D.rst
new file mode 100644
index 0000000..66ac042
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_1D.rst
@@ -0,0 +1,257 @@
+
+
+.. _sphx_glr_auto_examples_plot_barycenter_1D.py:
+
+
+==============================
+1D Wasserstein barycenter demo
+==============================
+
+This example illustrates the computation of regularized Wassersyein Barycenter
+as proposed in [3].
+
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ # necessary for 3d plot even if not used
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ from matplotlib.collections import PolyCollection
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+ #%% barycenter computation
+
+ alpha = 0.2 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # l2bary
+ bary_l2 = A.dot(weights)
+
+ # wasserstein
+ reg = 1e-3
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
+ :align: center
+
+
+
+
+Barycentric interpolation
+-------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% barycenter interpolation
+
+ n_alpha = 11
+ alpha_list = np.linspace(0, 1, n_alpha)
+
+
+ B_l2 = np.zeros((n, n_alpha))
+
+ B_wass = np.copy(B_l2)
+
+ for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
+
+ #%% plot interpolation
+
+ pl.figure(3)
+
+ cmap = pl.cm.get_cmap('viridis')
+ verts = []
+ zs = alpha_list
+ for i, z in enumerate(zs):
+ ys = B_l2[:, i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel('$\\alpha$')
+ ax.set_ylim3d(0, 1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
+ pl.title('Barycenter interpolation with l2')
+ pl.tight_layout()
+
+ pl.figure(4)
+ cmap = pl.cm.get_cmap('viridis')
+ verts = []
+ zs = alpha_list
+ for i, z in enumerate(zs):
+ ys = B_wass[:, i]
+ verts.append(list(zip(x, ys)))
+
+ ax = pl.gcf().gca(projection='3d')
+
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
+ poly.set_alpha(0.7)
+ ax.add_collection3d(poly, zs=zs, zdir='y')
+ ax.set_xlabel('x')
+ ax.set_xlim3d(0, n)
+ ax.set_ylabel('$\\alpha$')
+ ax.set_ylim3d(0, 1)
+ ax.set_zlabel('')
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
+ pl.title('Barycenter interpolation with Wasserstein')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.413 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_barycenter_1D.py <plot_barycenter_1D.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_barycenter_1D.ipynb <plot_barycenter_1D.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_barycenter_fgw.ipynb b/docs/source/auto_examples/plot_barycenter_fgw.ipynb
new file mode 100644
index 0000000..28229b2
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_fgw.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n=================================\nPlot graphs' barycenter using FGW\n=================================\n\nThis example illustrates the computation barycenter of labeled graphs using FGW\n\nRequires networkx >=2\n\n.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain\n and Courty Nicolas\n \"Optimal Transport for structured data with application on graphs\"\n International Conference on Machine Learning (ICML). 2019.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Titouan Vayer <titouan.vayer@irisa.fr>\n#\n# License: MIT License\n\n#%% load libraries\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport networkx as nx\nimport math\nfrom scipy.sparse.csgraph import shortest_path\nimport matplotlib.colors as mcol\nfrom matplotlib import cm\nfrom ot.gromov import fgw_barycenters\n#%% Graph functions\n\n\ndef find_thresh(C, inf=0.5, sup=3, step=10):\n \"\"\" Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected\n Tthe threshold is found by a linesearch between values \"inf\" and \"sup\" with \"step\" thresholds tested.\n The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix\n and the original matrix.\n Parameters\n ----------\n C : ndarray, shape (n_nodes,n_nodes)\n The structure matrix to threshold\n inf : float\n The beginning of the linesearch\n sup : float\n The end of the linesearch\n step : integer\n Number of thresholds tested\n \"\"\"\n dist = []\n search = np.linspace(inf, sup, step)\n for thresh in search:\n Cprime = sp_to_adjency(C, 0, thresh)\n SC = shortest_path(Cprime, method='D')\n SC[SC == float('inf')] = 100\n dist.append(np.linalg.norm(SC - C))\n return search[np.argmin(dist)], dist\n\n\ndef sp_to_adjency(C, threshinf=0.2, threshsup=1.8):\n \"\"\" Thresholds the structure matrix in order to compute an adjency matrix.\n All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0\n Parameters\n ----------\n C : ndarray, shape (n_nodes,n_nodes)\n The structure matrix to threshold\n threshinf : float\n The minimum value of distance from which the new value is set to 1\n threshsup : float\n The maximum value of distance from which the new value is set to 1\n Returns\n -------\n C : ndarray, shape (n_nodes,n_nodes)\n The threshold matrix. Each element is in {0,1}\n \"\"\"\n H = np.zeros_like(C)\n np.fill_diagonal(H, np.diagonal(C))\n C = C - H\n C = np.minimum(np.maximum(C, threshinf), threshsup)\n C[C == threshsup] = 0\n C[C != 0] = 1\n\n return C\n\n\ndef build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):\n \"\"\" Create a noisy circular graph\n \"\"\"\n g = nx.Graph()\n g.add_nodes_from(list(range(N)))\n for i in range(N):\n noise = float(np.random.normal(mu, sigma, 1))\n if with_noise:\n g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)\n else:\n g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))\n g.add_edge(i, i + 1)\n if structure_noise:\n randomint = np.random.randint(0, p)\n if randomint == 0:\n if i <= N - 3:\n g.add_edge(i, i + 2)\n if i == N - 2:\n g.add_edge(i, 0)\n if i == N - 1:\n g.add_edge(i, 1)\n g.add_edge(N, 0)\n noise = float(np.random.normal(mu, sigma, 1))\n if with_noise:\n g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)\n else:\n g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))\n return g\n\n\ndef graph_colors(nx_graph, vmin=0, vmax=7):\n cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)\n cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')\n cpick.set_array([])\n val_map = {}\n for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():\n val_map[k] = cpick.to_rgba(v)\n colors = []\n for node in nx_graph.nodes():\n colors.append(val_map[node])\n return colors"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% circular dataset\n# We build a dataset of noisy circular graphs.\n# Noise is added on the structures by random connections and on the features by gaussian noise.\n\n\nnp.random.seed(30)\nX0 = []\nfor k in range(9):\n X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Plot graphs\n\nplt.figure(figsize=(8, 10))\nfor i in range(len(X0)):\n plt.subplot(3, 3, i + 1)\n g = X0[i]\n pos = nx.kamada_kawai_layout(g)\n nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)\nplt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)\nplt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph\n# Features distances are the euclidean distances\nCs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]\nps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]\nYs = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]\nlambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()\nsizebary = 15 # we choose a barycenter with 15 nodes\n\nA, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot Barycenter\n-------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Create the barycenter\nbary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))\nfor i, v in enumerate(A.ravel()):\n bary.add_node(i, attr_name=v)\n\n#%%\npos = nx.kamada_kawai_layout(bary)\nnx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)\nplt.suptitle('Barycenter', fontsize=20)\nplt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_barycenter_fgw.py b/docs/source/auto_examples/plot_barycenter_fgw.py
new file mode 100644
index 0000000..77b0370
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_fgw.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+"""
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+#%% load libraries
+import numpy as np
+import matplotlib.pyplot as plt
+import networkx as nx
+import math
+from scipy.sparse.csgraph import shortest_path
+import matplotlib.colors as mcol
+from matplotlib import cm
+from ot.gromov import fgw_barycenters
+#%% Graph functions
+
+
+def find_thresh(C, inf=0.5, sup=3, step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist = []
+ search = np.linspace(inf, sup, step)
+ for thresh in search:
+ Cprime = sp_to_adjency(C, 0, thresh)
+ SC = shortest_path(Cprime, method='D')
+ SC[SC == float('inf')] = 100
+ dist.append(np.linalg.norm(SC - C))
+ return search[np.argmin(dist)], dist
+
+
+def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance from which the new value is set to 1
+ threshsup : float
+ The maximum value of distance from which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The threshold matrix. Each element is in {0,1}
+ """
+ H = np.zeros_like(C)
+ np.fill_diagonal(H, np.diagonal(C))
+ C = C - H
+ C = np.minimum(np.maximum(C, threshinf), threshsup)
+ C[C == threshsup] = 0
+ C[C != 0] = 1
+
+ return C
+
+
+def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
+ """ Create a noisy circular graph
+ """
+ g = nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
+ else:
+ g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
+ g.add_edge(i, i + 1)
+ if structure_noise:
+ randomint = np.random.randint(0, p)
+ if randomint == 0:
+ if i <= N - 3:
+ g.add_edge(i, i + 2)
+ if i == N - 2:
+ g.add_edge(i, 0)
+ if i == N - 1:
+ g.add_edge(i, 1)
+ g.add_edge(N, 0)
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
+ else:
+ g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
+ return g
+
+
+def graph_colors(nx_graph, vmin=0, vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
+ val_map[k] = cpick.to_rgba(v)
+ colors = []
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% circular dataset
+# We build a dataset of noisy circular graphs.
+# Noise is added on the structures by random connections and on the features by gaussian noise.
+
+
+np.random.seed(30)
+X0 = []
+for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% Plot graphs
+
+plt.figure(figsize=(8, 10))
+for i in range(len(X0)):
+ plt.subplot(3, 3, i + 1)
+ g = X0[i]
+ pos = nx.kamada_kawai_layout(g)
+ nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
+plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
+plt.show()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+# Features distances are the euclidean distances
+Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
+Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
+lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
+sizebary = 15 # we choose a barycenter with 15 nodes
+
+A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
+
+##############################################################################
+# Plot Barycenter
+# -------------------------
+
+#%% Create the barycenter
+bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
+
+#%%
+pos = nx.kamada_kawai_layout(bary)
+nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
+plt.suptitle('Barycenter', fontsize=20)
+plt.show()
diff --git a/docs/source/auto_examples/plot_barycenter_fgw.rst b/docs/source/auto_examples/plot_barycenter_fgw.rst
new file mode 100644
index 0000000..2c44a65
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_fgw.rst
@@ -0,0 +1,268 @@
+
+
+.. _sphx_glr_auto_examples_plot_barycenter_fgw.py:
+
+
+=================================
+Plot graphs' barycenter using FGW
+=================================
+
+This example illustrates the computation barycenter of labeled graphs using FGW
+
+Requires networkx >=2
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Titouan Vayer <titouan.vayer@irisa.fr>
+ #
+ # License: MIT License
+
+ #%% load libraries
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import networkx as nx
+ import math
+ from scipy.sparse.csgraph import shortest_path
+ import matplotlib.colors as mcol
+ from matplotlib import cm
+ from ot.gromov import fgw_barycenters
+ #%% Graph functions
+
+
+ def find_thresh(C, inf=0.5, sup=3, step=10):
+ """ Trick to find the adequate thresholds from where value of the C matrix are considered close enough to say that nodes are connected
+ Tthe threshold is found by a linesearch between values "inf" and "sup" with "step" thresholds tested.
+ The optimal threshold is the one which minimizes the reconstruction error between the shortest_path matrix coming from the thresholded adjency matrix
+ and the original matrix.
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ inf : float
+ The beginning of the linesearch
+ sup : float
+ The end of the linesearch
+ step : integer
+ Number of thresholds tested
+ """
+ dist = []
+ search = np.linspace(inf, sup, step)
+ for thresh in search:
+ Cprime = sp_to_adjency(C, 0, thresh)
+ SC = shortest_path(Cprime, method='D')
+ SC[SC == float('inf')] = 100
+ dist.append(np.linalg.norm(SC - C))
+ return search[np.argmin(dist)], dist
+
+
+ def sp_to_adjency(C, threshinf=0.2, threshsup=1.8):
+ """ Thresholds the structure matrix in order to compute an adjency matrix.
+ All values between threshinf and threshsup are considered representing connected nodes and set to 1. Else are set to 0
+ Parameters
+ ----------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The structure matrix to threshold
+ threshinf : float
+ The minimum value of distance from which the new value is set to 1
+ threshsup : float
+ The maximum value of distance from which the new value is set to 1
+ Returns
+ -------
+ C : ndarray, shape (n_nodes,n_nodes)
+ The threshold matrix. Each element is in {0,1}
+ """
+ H = np.zeros_like(C)
+ np.fill_diagonal(H, np.diagonal(C))
+ C = C - H
+ C = np.minimum(np.maximum(C, threshinf), threshsup)
+ C[C == threshsup] = 0
+ C[C != 0] = 1
+
+ return C
+
+
+ def build_noisy_circular_graph(N=20, mu=0, sigma=0.3, with_noise=False, structure_noise=False, p=None):
+ """ Create a noisy circular graph
+ """
+ g = nx.Graph()
+ g.add_nodes_from(list(range(N)))
+ for i in range(N):
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
+ else:
+ g.add_node(i, attr_name=math.sin(2 * i * math.pi / N))
+ g.add_edge(i, i + 1)
+ if structure_noise:
+ randomint = np.random.randint(0, p)
+ if randomint == 0:
+ if i <= N - 3:
+ g.add_edge(i, i + 2)
+ if i == N - 2:
+ g.add_edge(i, 0)
+ if i == N - 1:
+ g.add_edge(i, 1)
+ g.add_edge(N, 0)
+ noise = float(np.random.normal(mu, sigma, 1))
+ if with_noise:
+ g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
+ else:
+ g.add_node(N, attr_name=math.sin(2 * N * math.pi / N))
+ return g
+
+
+ def graph_colors(nx_graph, vmin=0, vmax=7):
+ cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)
+ cpick = cm.ScalarMappable(norm=cnorm, cmap='viridis')
+ cpick.set_array([])
+ val_map = {}
+ for k, v in nx.get_node_attributes(nx_graph, 'attr_name').items():
+ val_map[k] = cpick.to_rgba(v)
+ colors = []
+ for node in nx_graph.nodes():
+ colors.append(val_map[node])
+ return colors
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% circular dataset
+ # We build a dataset of noisy circular graphs.
+ # Noise is added on the structures by random connections and on the features by gaussian noise.
+
+
+ np.random.seed(30)
+ X0 = []
+ for k in range(9):
+ X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% Plot graphs
+
+ plt.figure(figsize=(8, 10))
+ for i in range(len(X0)):
+ plt.subplot(3, 3, i + 1)
+ g = X0[i]
+ pos = nx.kamada_kawai_layout(g)
+ nx.draw(g, pos=pos, node_color=graph_colors(g, vmin=-1, vmax=1), with_labels=False, node_size=100)
+ plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
+ plt.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_fgw_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+ #%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
+ # Features distances are the euclidean distances
+ Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
+ ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
+ Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
+ lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
+ sizebary = 15 # we choose a barycenter with 15 nodes
+
+ A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95, log=True)
+
+
+
+
+
+
+
+Plot Barycenter
+-------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Create the barycenter
+ bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
+ for i, v in enumerate(A.ravel()):
+ bary.add_node(i, attr_name=v)
+
+ #%%
+ pos = nx.kamada_kawai_layout(bary)
+ nx.draw(bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False)
+ plt.suptitle('Barycenter', fontsize=20)
+ plt.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_fgw_002.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 2.065 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_barycenter_fgw.py <plot_barycenter_fgw.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_barycenter_fgw.ipynb <plot_barycenter_fgw.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.ipynb b/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.ipynb
new file mode 100644
index 0000000..2199162
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.ipynb
@@ -0,0 +1,108 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 1D Wasserstein barycenter comparison between exact LP and entropic regularization\n\n\nThis example illustrates the computation of regularized Wasserstein Barycenter\nas proposed in [3] and exact LP barycenters using standard LP solver.\n\nIt reproduces approximately Figure 3.1 and 3.2 from the following paper:\nCuturi, M., & Peyr\u00e9, G. (2016). A smoothed dual approach for variational\nWasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection # noqa\n\n#import ot.lp.cvx as cvx"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Gaussian Data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nproblems = []\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\n# Gaussian distributions\na1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Dirac Data\n----------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\na1 = 1.0 * (x > 10) * (x < 50)\na2 = 1.0 * (x > 60) * (x < 80)\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()\n\n#%% parameters\n\na1 = np.zeros(n)\na2 = np.zeros(n)\n\na1[10] = .25\na1[20] = .5\na1[30] = .25\na2[80] = 1\n\n\na1 /= a1.sum()\na2 /= a2.sum()\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()\n\n\n#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()\n\n\n#%% barycenter computation\n\nalpha = 0.5 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\not.tic()\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\not.toc()\n\n\not.tic()\nbary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)\not.toc()\n\n\nproblems.append([A, [bary_l2, bary_wass, bary_wass2]])\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\npl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Final figure\n------------\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot\n\nnbm = len(problems)\nnbm2 = (nbm // 2)\n\n\npl.figure(2, (20, 6))\npl.clf()\n\nfor i in range(nbm):\n\n A = problems[i][0]\n bary_l2 = problems[i][1][0]\n bary_wass = problems[i][1][1]\n bary_wass2 = problems[i][1][2]\n\n pl.subplot(2, nbm, 1 + i)\n for j in range(n_distributions):\n pl.plot(x, A[:, j])\n if i == nbm2:\n pl.title('Distributions')\n pl.xticks(())\n pl.yticks(())\n\n pl.subplot(2, nbm, 1 + i + nbm)\n\n pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')\n pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')\n pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')\n if i == nbm - 1:\n pl.legend()\n if i == nbm2:\n pl.title('Barycenters')\n\n pl.xticks(())\n pl.yticks(())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.py b/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.py
new file mode 100644
index 0000000..b82765e
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.py
@@ -0,0 +1,281 @@
+# -*- coding: utf-8 -*-
+"""
+=================================================================================
+1D Wasserstein barycenter comparison between exact LP and entropic regularization
+=================================================================================
+
+This example illustrates the computation of regularized Wasserstein Barycenter
+as proposed in [3] and exact LP barycenters using standard LP solver.
+
+It reproduces approximately Figure 3.1 and 3.2 from the following paper:
+Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
+Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
+from matplotlib.collections import PolyCollection # noqa
+
+#import ot.lp.cvx as cvx
+
+##############################################################################
+# Gaussian Data
+# -------------
+
+#%% parameters
+
+problems = []
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+# Gaussian distributions
+a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+##############################################################################
+# Dirac Data
+# ----------
+
+#%% parameters
+
+a1 = 1.0 * (x > 10) * (x < 50)
+a2 = 1.0 * (x > 60) * (x < 80)
+
+a1 /= a1.sum()
+a2 /= a2.sum()
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+#%% parameters
+
+a1 = np.zeros(n)
+a2 = np.zeros(n)
+
+a1[10] = .25
+a1[20] = .5
+a1[30] = .25
+a2[80] = 1
+
+
+a1 /= a1.sum()
+a2 /= a2.sum()
+
+# creating matrix A containing all distributions
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
+
+# loss matrix + normalization
+M = ot.utils.dist0(n)
+M /= M.max()
+
+
+#%% plot the distributions
+
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+pl.tight_layout()
+
+
+#%% barycenter computation
+
+alpha = 0.5 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
+
+# l2bary
+bary_l2 = A.dot(weights)
+
+# wasserstein
+reg = 1e-3
+ot.tic()
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ot.toc()
+
+
+ot.tic()
+bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ot.toc()
+
+
+problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+pl.title('Distributions')
+
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+pl.legend()
+pl.title('Barycenters')
+pl.tight_layout()
+
+
+##############################################################################
+# Final figure
+# ------------
+#
+
+#%% plot
+
+nbm = len(problems)
+nbm2 = (nbm // 2)
+
+
+pl.figure(2, (20, 6))
+pl.clf()
+
+for i in range(nbm):
+
+ A = problems[i][0]
+ bary_l2 = problems[i][1][0]
+ bary_wass = problems[i][1][1]
+ bary_wass2 = problems[i][1][2]
+
+ pl.subplot(2, nbm, 1 + i)
+ for j in range(n_distributions):
+ pl.plot(x, A[:, j])
+ if i == nbm2:
+ pl.title('Distributions')
+ pl.xticks(())
+ pl.yticks(())
+
+ pl.subplot(2, nbm, 1 + i + nbm)
+
+ pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ if i == nbm - 1:
+ pl.legend()
+ if i == nbm2:
+ pl.title('Barycenters')
+
+ pl.xticks(())
+ pl.yticks(())
diff --git a/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.rst b/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.rst
new file mode 100644
index 0000000..bd1c710
--- /dev/null
+++ b/docs/source/auto_examples/plot_barycenter_lp_vs_entropic.rst
@@ -0,0 +1,447 @@
+
+
+.. _sphx_glr_auto_examples_plot_barycenter_lp_vs_entropic.py:
+
+
+=================================================================================
+1D Wasserstein barycenter comparison between exact LP and entropic regularization
+=================================================================================
+
+This example illustrates the computation of regularized Wasserstein Barycenter
+as proposed in [3] and exact LP barycenters using standard LP solver.
+
+It reproduces approximately Figure 3.1 and 3.2 from the following paper:
+Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
+Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ # necessary for 3d plot even if not used
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ from matplotlib.collections import PolyCollection # noqa
+
+ #import ot.lp.cvx as cvx
+
+
+
+
+
+
+
+Gaussian Data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ problems = []
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ # Gaussian distributions
+ a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+ pl.tight_layout()
+
+ #%% barycenter computation
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # l2bary
+ bary_l2 = A.dot(weights)
+
+ # wasserstein
+ reg = 1e-3
+ ot.tic()
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ ot.toc()
+
+
+ ot.tic()
+ bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ ot.toc()
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+ pl.tight_layout()
+
+ problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_002.png
+ :scale: 47
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Elapsed time : 0.010712385177612305 s
+ Primal Feasibility Dual Feasibility Duality Gap Step Path Parameter Objective
+ 1.0 1.0 1.0 - 1.0 1700.336700337
+ 0.006776453137632 0.006776453137633 0.006776453137633 0.9932238647293 0.006776453137633 125.6700527543
+ 0.004018712867874 0.004018712867874 0.004018712867874 0.4301142633 0.004018712867874 12.26594150093
+ 0.001172775061627 0.001172775061627 0.001172775061627 0.7599932455029 0.001172775061627 0.3378536968897
+ 0.0004375137005385 0.0004375137005385 0.0004375137005385 0.6422331807989 0.0004375137005385 0.1468420566358
+ 0.000232669046734 0.0002326690467341 0.000232669046734 0.5016999460893 0.000232669046734 0.09381703231432
+ 7.430121674303e-05 7.430121674303e-05 7.430121674303e-05 0.7035962305812 7.430121674303e-05 0.0577787025717
+ 5.321227838876e-05 5.321227838875e-05 5.321227838876e-05 0.308784186441 5.321227838876e-05 0.05266249477203
+ 1.990900379199e-05 1.990900379196e-05 1.990900379199e-05 0.6520472013244 1.990900379199e-05 0.04526054405519
+ 6.305442046799e-06 6.30544204682e-06 6.3054420468e-06 0.7073953304075 6.305442046798e-06 0.04237597591383
+ 2.290148391577e-06 2.290148391582e-06 2.290148391578e-06 0.6941812711492 2.29014839159e-06 0.041522849321
+ 1.182864875387e-06 1.182864875406e-06 1.182864875427e-06 0.508455204675 1.182864875445e-06 0.04129461872827
+ 3.626786381529e-07 3.626786382468e-07 3.626786382923e-07 0.7101651572101 3.62678638267e-07 0.04113032448923
+ 1.539754244902e-07 1.539754249276e-07 1.539754249356e-07 0.6279322066282 1.539754253892e-07 0.04108867636379
+ 5.193221323143e-08 5.193221463044e-08 5.193221462729e-08 0.6843453436759 5.193221708199e-08 0.04106859618414
+ 1.888205054507e-08 1.888204779723e-08 1.88820477688e-08 0.6673444085651 1.888205650952e-08 0.041062141752
+ 5.676855206925e-09 5.676854518888e-09 5.676854517651e-09 0.7281705804232 5.676885442702e-09 0.04105958648713
+ 3.501157668218e-09 3.501150243546e-09 3.501150216347e-09 0.414020345194 3.501164437194e-09 0.04105916265261
+ 1.110594251499e-09 1.110590786827e-09 1.11059083379e-09 0.6998954759911 1.110636623476e-09 0.04105870073485
+ 5.770971626386e-10 5.772456113791e-10 5.772456200156e-10 0.4999769658132 5.77013379477e-10 0.04105859769135
+ 1.535218204536e-10 1.536993317032e-10 1.536992771966e-10 0.7516471627141 1.536205005991e-10 0.04105851679958
+ 6.724209350756e-11 6.739211232927e-11 6.739210470901e-11 0.5944802416166 6.735465384341e-11 0.04105850033766
+ 1.743382199199e-11 1.736445896691e-11 1.736448490761e-11 0.7573407808104 1.734254328931e-11 0.04105849088824
+ Optimization terminated successfully.
+ Elapsed time : 2.883899211883545 s
+
+
+Dirac Data
+----------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ a1 = 1.0 * (x > 10) * (x < 50)
+ a2 = 1.0 * (x > 60) * (x < 80)
+
+ a1 /= a1.sum()
+ a2 /= a2.sum()
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+ pl.tight_layout()
+
+
+ #%% barycenter computation
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # l2bary
+ bary_l2 = A.dot(weights)
+
+ # wasserstein
+ reg = 1e-3
+ ot.tic()
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ ot.toc()
+
+
+ ot.tic()
+ bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ ot.toc()
+
+
+ problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+ pl.tight_layout()
+
+ #%% parameters
+
+ a1 = np.zeros(n)
+ a2 = np.zeros(n)
+
+ a1[10] = .25
+ a1[20] = .5
+ a1[30] = .25
+ a2[80] = 1
+
+
+ a1 /= a1.sum()
+ a2 /= a2.sum()
+
+ # creating matrix A containing all distributions
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
+
+ # loss matrix + normalization
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+ pl.tight_layout()
+
+
+ #%% barycenter computation
+
+ alpha = 0.5 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
+
+ # l2bary
+ bary_l2 = A.dot(weights)
+
+ # wasserstein
+ reg = 1e-3
+ ot.tic()
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
+ ot.toc()
+
+
+ ot.tic()
+ bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
+ ot.toc()
+
+
+ problems.append([A, [bary_l2, bary_wass, bary_wass2]])
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
+ pl.title('Distributions')
+
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ pl.legend()
+ pl.title('Barycenters')
+ pl.tight_layout()
+
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_003.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_004.png
+ :scale: 47
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Elapsed time : 0.014938592910766602 s
+ Primal Feasibility Dual Feasibility Duality Gap Step Path Parameter Objective
+ 1.0 1.0 1.0 - 1.0 1700.336700337
+ 0.006776466288966 0.006776466288966 0.006776466288966 0.9932238515788 0.006776466288966 125.6649255808
+ 0.004036918865495 0.004036918865495 0.004036918865495 0.4272973099316 0.004036918865495 12.3471617011
+ 0.00121923268707 0.00121923268707 0.00121923268707 0.749698685599 0.00121923268707 0.3243835647408
+ 0.0003837422984432 0.0003837422984432 0.0003837422984432 0.6926882608284 0.0003837422984432 0.1361719397493
+ 0.0001070128410183 0.0001070128410183 0.0001070128410183 0.7643889137854 0.0001070128410183 0.07581952832518
+ 0.0001001275033711 0.0001001275033711 0.0001001275033711 0.07058704837812 0.0001001275033712 0.0734739493635
+ 4.550897507844e-05 4.550897507841e-05 4.550897507844e-05 0.5761172484828 4.550897507845e-05 0.05555077655047
+ 8.557124125522e-06 8.5571241255e-06 8.557124125522e-06 0.8535925441152 8.557124125522e-06 0.04439814660221
+ 3.611995628407e-06 3.61199562841e-06 3.611995628414e-06 0.6002277331554 3.611995628415e-06 0.04283007762152
+ 7.590393750365e-07 7.590393750491e-07 7.590393750378e-07 0.8221486533416 7.590393750381e-07 0.04192322976248
+ 8.299929287441e-08 8.299929286079e-08 8.299929287532e-08 0.9017467938799 8.29992928758e-08 0.04170825633295
+ 3.117560203449e-10 3.117560130137e-10 3.11756019954e-10 0.997039969226 3.11756019952e-10 0.04168179329766
+ 1.559749653711e-14 1.558073160926e-14 1.559756940692e-14 0.9999499686183 1.559750643989e-14 0.04168169240444
+ Optimization terminated successfully.
+ Elapsed time : 2.642659902572632 s
+ Elapsed time : 0.002908945083618164 s
+ Primal Feasibility Dual Feasibility Duality Gap Step Path Parameter Objective
+ 1.0 1.0 1.0 - 1.0 1700.336700337
+ 0.006774675520727 0.006774675520727 0.006774675520727 0.9932256422636 0.006774675520727 125.6956034743
+ 0.002048208707562 0.002048208707562 0.002048208707562 0.7343095368143 0.002048208707562 5.213991622123
+ 0.000269736547478 0.0002697365474781 0.0002697365474781 0.8839403501193 0.000269736547478 0.505938390389
+ 6.832109993943e-05 6.832109993944e-05 6.832109993944e-05 0.7601171075965 6.832109993943e-05 0.2339657807272
+ 2.437682932219e-05 2.43768293222e-05 2.437682932219e-05 0.6663448297475 2.437682932219e-05 0.1471256246325
+ 1.13498321631e-05 1.134983216308e-05 1.13498321631e-05 0.5553643816404 1.13498321631e-05 0.1181584941171
+ 3.342312725885e-06 3.342312725884e-06 3.342312725885e-06 0.7238133571615 3.342312725885e-06 0.1006387519747
+ 7.078561231603e-07 7.078561231509e-07 7.078561231604e-07 0.8033142552512 7.078561231603e-07 0.09474734646269
+ 1.966870956916e-07 1.966870954537e-07 1.966870954468e-07 0.752547917788 1.966870954633e-07 0.09354342735766
+ 4.19989524849e-10 4.199895164852e-10 4.199895238758e-10 0.9984019849375 4.19989523951e-10 0.09310367785861
+ 2.101015938666e-14 2.100625691113e-14 2.101023853438e-14 0.999949974425 2.101023691864e-14 0.09310274466458
+ Optimization terminated successfully.
+ Elapsed time : 2.690450668334961 s
+
+
+Final figure
+------------
+
+
+
+
+.. code-block:: python
+
+
+ #%% plot
+
+ nbm = len(problems)
+ nbm2 = (nbm // 2)
+
+
+ pl.figure(2, (20, 6))
+ pl.clf()
+
+ for i in range(nbm):
+
+ A = problems[i][0]
+ bary_l2 = problems[i][1][0]
+ bary_wass = problems[i][1][1]
+ bary_wass2 = problems[i][1][2]
+
+ pl.subplot(2, nbm, 1 + i)
+ for j in range(n_distributions):
+ pl.plot(x, A[:, j])
+ if i == nbm2:
+ pl.title('Distributions')
+ pl.xticks(())
+ pl.yticks(())
+
+ pl.subplot(2, nbm, 1 + i + nbm)
+
+ pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
+ pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
+ pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
+ if i == nbm - 1:
+ pl.legend()
+ if i == nbm2:
+ pl.title('Barycenters')
+
+ pl.xticks(())
+ pl.yticks(())
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_lp_vs_entropic_006.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 8.892 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_barycenter_lp_vs_entropic.py <plot_barycenter_lp_vs_entropic.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_barycenter_lp_vs_entropic.ipynb <plot_barycenter_lp_vs_entropic.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_compute_emd.ipynb b/docs/source/auto_examples/plot_compute_emd.ipynb
new file mode 100644
index 0000000..562eff8
--- /dev/null
+++ b/docs/source/auto_examples/plot_compute_emd.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Plot multiple EMD\n\n\nShows how to compute multiple EMD and Sinkhorn with two differnt\nground metrics and plot their values for diffeent distributions.\n\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import make_1D_gauss as gauss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\nn_target = 50 # nb target distributions\n\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\nlst_m = np.linspace(20, 90, n_target)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\n\nB = np.zeros((n, n_target))\n\nfor i, m in enumerate(lst_m):\n B[:, i] = gauss(n, m=m, s=5)\n\n# loss matrix and normalization\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')\nM /= M.max()\nM2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')\nM2 /= M2.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1)\npl.subplot(2, 1, 1)\npl.plot(x, a, 'b', label='Source distribution')\npl.title('Source distribution')\npl.subplot(2, 1, 2)\npl.plot(x, B, label='Target distributions')\npl.title('Target distributions')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute EMD for the different losses\n------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Compute and plot distributions and loss matrix\n\nd_emd = ot.emd2(a, B, M) # direct computation of EMD\nd_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2\n\n\npl.figure(2)\npl.plot(d_emd, label='Euclidean EMD')\npl.plot(d_emd2, label='Squared Euclidean EMD')\npl.title('EMD distances')\npl.legend()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute Sinkhorn for the different losses\n-----------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%%\nreg = 1e-2\nd_sinkhorn = ot.sinkhorn2(a, B, M, reg)\nd_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)\n\npl.figure(2)\npl.clf()\npl.plot(d_emd, label='Euclidean EMD')\npl.plot(d_emd2, label='Squared Euclidean EMD')\npl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')\npl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')\npl.title('EMD distances')\npl.legend()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_compute_emd.py b/docs/source/auto_examples/plot_compute_emd.py
new file mode 100644
index 0000000..7ed2b01
--- /dev/null
+++ b/docs/source/auto_examples/plot_compute_emd.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+"""
+=================
+Plot multiple EMD
+=================
+
+Shows how to compute multiple EMD and Sinkhorn with two differnt
+ground metrics and plot their values for diffeent distributions.
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from ot.datasets import make_1D_gauss as gauss
+
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+n_target = 50 # nb target distributions
+
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+lst_m = np.linspace(20, 90, n_target)
+
+# Gaussian distributions
+a = gauss(n, m=20, s=5) # m= mean, s= std
+
+B = np.zeros((n, n_target))
+
+for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
+
+# loss matrix and normalization
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+M /= M.max()
+M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+M2 /= M2.max()
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.figure(1)
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'b', label='Source distribution')
+pl.title('Source distribution')
+pl.subplot(2, 1, 2)
+pl.plot(x, B, label='Target distributions')
+pl.title('Target distributions')
+pl.tight_layout()
+
+
+##############################################################################
+# Compute EMD for the different losses
+# ------------------------------------
+
+#%% Compute and plot distributions and loss matrix
+
+d_emd = ot.emd2(a, B, M) # direct computation of EMD
+d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
+
+
+pl.figure(2)
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
+pl.title('EMD distances')
+pl.legend()
+
+##############################################################################
+# Compute Sinkhorn for the different losses
+# -----------------------------------------
+
+#%%
+reg = 1e-2
+d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
+
+pl.figure(2)
+pl.clf()
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
+pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
+pl.title('EMD distances')
+pl.legend()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_compute_emd.rst b/docs/source/auto_examples/plot_compute_emd.rst
new file mode 100644
index 0000000..27bca2c
--- /dev/null
+++ b/docs/source/auto_examples/plot_compute_emd.rst
@@ -0,0 +1,189 @@
+
+
+.. _sphx_glr_auto_examples_plot_compute_emd.py:
+
+
+=================
+Plot multiple EMD
+=================
+
+Shows how to compute multiple EMD and Sinkhorn with two differnt
+ground metrics and plot their values for diffeent distributions.
+
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ from ot.datasets import make_1D_gauss as gauss
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+ n_target = 50 # nb target distributions
+
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ lst_m = np.linspace(20, 90, n_target)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+
+ B = np.zeros((n, n_target))
+
+ for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
+
+ # loss matrix and normalization
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+ M /= M.max()
+ M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+ M2 /= M2.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1)
+ pl.subplot(2, 1, 1)
+ pl.plot(x, a, 'b', label='Source distribution')
+ pl.title('Source distribution')
+ pl.subplot(2, 1, 2)
+ pl.plot(x, B, label='Target distributions')
+ pl.title('Target distributions')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_001.png
+ :align: center
+
+
+
+
+Compute EMD for the different losses
+------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Compute and plot distributions and loss matrix
+
+ d_emd = ot.emd2(a, B, M) # direct computation of EMD
+ d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
+
+
+ pl.figure(2)
+ pl.plot(d_emd, label='Euclidean EMD')
+ pl.plot(d_emd2, label='Squared Euclidean EMD')
+ pl.title('EMD distances')
+ pl.legend()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_003.png
+ :align: center
+
+
+
+
+Compute Sinkhorn for the different losses
+-----------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%%
+ reg = 1e-2
+ d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+ d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
+
+ pl.figure(2)
+ pl.clf()
+ pl.plot(d_emd, label='Euclidean EMD')
+ pl.plot(d_emd2, label='Squared Euclidean EMD')
+ pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+ pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
+ pl.title('EMD distances')
+ pl.legend()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.446 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_compute_emd.py <plot_compute_emd.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_compute_emd.ipynb <plot_compute_emd.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_convolutional_barycenter.ipynb b/docs/source/auto_examples/plot_convolutional_barycenter.ipynb
new file mode 100644
index 0000000..4981ba3
--- /dev/null
+++ b/docs/source/auto_examples/plot_convolutional_barycenter.ipynb
@@ -0,0 +1,90 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Convolutional Wasserstein Barycenter example\n\n\nThis example is designed to illustrate how the Convolutional Wasserstein Barycenter\nfunction of POT works.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Nicolas Courty <ncourty@irisa.fr>\n#\n# License: MIT License\n\n\nimport numpy as np\nimport pylab as pl\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Data preparation\n----------------\n\nThe four distributions are constructed from 4 simple images\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]\nf2 = 1 - pl.imread('../data/duck.png')[:, :, 2]\nf3 = 1 - pl.imread('../data/heart.png')[:, :, 2]\nf4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]\n\nA = []\nf1 = f1 / np.sum(f1)\nf2 = f2 / np.sum(f2)\nf3 = f3 / np.sum(f3)\nf4 = f4 / np.sum(f4)\nA.append(f1)\nA.append(f2)\nA.append(f3)\nA.append(f4)\nA = np.array(A)\n\nnb_images = 5\n\n# those are the four corners coordinates that will be interpolated by bilinear\n# interpolation\nv1 = np.array((1, 0, 0, 0))\nv2 = np.array((0, 1, 0, 0))\nv3 = np.array((0, 0, 1, 0))\nv4 = np.array((0, 0, 0, 1))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Barycenter computation and visualization\n----------------------------------------\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(figsize=(10, 10))\npl.title('Convolutional Wasserstein Barycenters in POT')\ncm = 'Blues'\n# regularization parameter\nreg = 0.004\nfor i in range(nb_images):\n for j in range(nb_images):\n pl.subplot(nb_images, nb_images, i * nb_images + j + 1)\n tx = float(i) / (nb_images - 1)\n ty = float(j) / (nb_images - 1)\n\n # weights are constructed by bilinear interpolation\n tmp1 = (1 - tx) * v1 + tx * v2\n tmp2 = (1 - tx) * v3 + tx * v4\n weights = (1 - ty) * tmp1 + ty * tmp2\n\n if i == 0 and j == 0:\n pl.imshow(f1, cmap=cm)\n pl.axis('off')\n elif i == 0 and j == (nb_images - 1):\n pl.imshow(f3, cmap=cm)\n pl.axis('off')\n elif i == (nb_images - 1) and j == 0:\n pl.imshow(f2, cmap=cm)\n pl.axis('off')\n elif i == (nb_images - 1) and j == (nb_images - 1):\n pl.imshow(f4, cmap=cm)\n pl.axis('off')\n else:\n # call to barycenter computation\n pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)\n pl.axis('off')\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_convolutional_barycenter.py b/docs/source/auto_examples/plot_convolutional_barycenter.py
new file mode 100644
index 0000000..e74db04
--- /dev/null
+++ b/docs/source/auto_examples/plot_convolutional_barycenter.py
@@ -0,0 +1,92 @@
+
+#%%
+# -*- coding: utf-8 -*-
+"""
+============================================
+Convolutional Wasserstein Barycenter example
+============================================
+
+This example is designed to illustrate how the Convolutional Wasserstein Barycenter
+function of POT works.
+"""
+
+# Author: Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
+f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
+f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
+f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+
+A = []
+f1 = f1 / np.sum(f1)
+f2 = f2 / np.sum(f2)
+f3 = f3 / np.sum(f3)
+f4 = f4 / np.sum(f4)
+A.append(f1)
+A.append(f2)
+A.append(f3)
+A.append(f4)
+A = np.array(A)
+
+nb_images = 5
+
+# those are the four corners coordinates that will be interpolated by bilinear
+# interpolation
+v1 = np.array((1, 0, 0, 0))
+v2 = np.array((0, 1, 0, 0))
+v3 = np.array((0, 0, 1, 0))
+v4 = np.array((0, 0, 0, 1))
+
+
+##############################################################################
+# Barycenter computation and visualization
+# ----------------------------------------
+#
+
+pl.figure(figsize=(10, 10))
+pl.title('Convolutional Wasserstein Barycenters in POT')
+cm = 'Blues'
+# regularization parameter
+reg = 0.004
+for i in range(nb_images):
+ for j in range(nb_images):
+ pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
+ tx = float(i) / (nb_images - 1)
+ ty = float(j) / (nb_images - 1)
+
+ # weights are constructed by bilinear interpolation
+ tmp1 = (1 - tx) * v1 + tx * v2
+ tmp2 = (1 - tx) * v3 + tx * v4
+ weights = (1 - ty) * tmp1 + ty * tmp2
+
+ if i == 0 and j == 0:
+ pl.imshow(f1, cmap=cm)
+ pl.axis('off')
+ elif i == 0 and j == (nb_images - 1):
+ pl.imshow(f3, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == 0:
+ pl.imshow(f2, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == (nb_images - 1):
+ pl.imshow(f4, cmap=cm)
+ pl.axis('off')
+ else:
+ # call to barycenter computation
+ pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
+ pl.axis('off')
+pl.show()
diff --git a/docs/source/auto_examples/plot_convolutional_barycenter.rst b/docs/source/auto_examples/plot_convolutional_barycenter.rst
new file mode 100644
index 0000000..a28db2f
--- /dev/null
+++ b/docs/source/auto_examples/plot_convolutional_barycenter.rst
@@ -0,0 +1,151 @@
+
+
+.. _sphx_glr_auto_examples_plot_convolutional_barycenter.py:
+
+
+============================================
+Convolutional Wasserstein Barycenter example
+============================================
+
+This example is designed to illustrate how the Convolutional Wasserstein Barycenter
+function of POT works.
+
+
+
+.. code-block:: python
+
+
+ # Author: Nicolas Courty <ncourty@irisa.fr>
+ #
+ # License: MIT License
+
+
+ import numpy as np
+ import pylab as pl
+ import ot
+
+
+
+
+
+
+
+Data preparation
+----------------
+
+The four distributions are constructed from 4 simple images
+
+
+
+.. code-block:: python
+
+
+
+ f1 = 1 - pl.imread('../data/redcross.png')[:, :, 2]
+ f2 = 1 - pl.imread('../data/duck.png')[:, :, 2]
+ f3 = 1 - pl.imread('../data/heart.png')[:, :, 2]
+ f4 = 1 - pl.imread('../data/tooth.png')[:, :, 2]
+
+ A = []
+ f1 = f1 / np.sum(f1)
+ f2 = f2 / np.sum(f2)
+ f3 = f3 / np.sum(f3)
+ f4 = f4 / np.sum(f4)
+ A.append(f1)
+ A.append(f2)
+ A.append(f3)
+ A.append(f4)
+ A = np.array(A)
+
+ nb_images = 5
+
+ # those are the four corners coordinates that will be interpolated by bilinear
+ # interpolation
+ v1 = np.array((1, 0, 0, 0))
+ v2 = np.array((0, 1, 0, 0))
+ v3 = np.array((0, 0, 1, 0))
+ v4 = np.array((0, 0, 0, 1))
+
+
+
+
+
+
+
+
+Barycenter computation and visualization
+----------------------------------------
+
+
+
+
+.. code-block:: python
+
+
+ pl.figure(figsize=(10, 10))
+ pl.title('Convolutional Wasserstein Barycenters in POT')
+ cm = 'Blues'
+ # regularization parameter
+ reg = 0.004
+ for i in range(nb_images):
+ for j in range(nb_images):
+ pl.subplot(nb_images, nb_images, i * nb_images + j + 1)
+ tx = float(i) / (nb_images - 1)
+ ty = float(j) / (nb_images - 1)
+
+ # weights are constructed by bilinear interpolation
+ tmp1 = (1 - tx) * v1 + tx * v2
+ tmp2 = (1 - tx) * v3 + tx * v4
+ weights = (1 - ty) * tmp1 + ty * tmp2
+
+ if i == 0 and j == 0:
+ pl.imshow(f1, cmap=cm)
+ pl.axis('off')
+ elif i == 0 and j == (nb_images - 1):
+ pl.imshow(f3, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == 0:
+ pl.imshow(f2, cmap=cm)
+ pl.axis('off')
+ elif i == (nb_images - 1) and j == (nb_images - 1):
+ pl.imshow(f4, cmap=cm)
+ pl.axis('off')
+ else:
+ # call to barycenter computation
+ pl.imshow(ot.bregman.convolutional_barycenter2d(A, reg, weights), cmap=cm)
+ pl.axis('off')
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_convolutional_barycenter_001.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 1 minutes 11.608 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_convolutional_barycenter.py <plot_convolutional_barycenter.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_convolutional_barycenter.ipynb <plot_convolutional_barycenter.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_fgw.ipynb b/docs/source/auto_examples/plot_fgw.ipynb
new file mode 100644
index 0000000..1b150bd
--- /dev/null
+++ b/docs/source/auto_examples/plot_fgw.ipynb
@@ -0,0 +1,162 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Plot Fused-gromov-Wasserstein\n\n\nThis example illustrates the computation of FGW for 1D measures[18].\n\n.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain\n and Courty Nicolas\n \"Optimal Transport for structured data with application on graphs\"\n International Conference on Machine Learning (ICML). 2019.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Titouan Vayer <titouan.vayer@irisa.fr>\n#\n# License: MIT License\n\nimport matplotlib.pyplot as pl\nimport numpy as np\nimport ot\nfrom ot.gromov import gromov_wasserstein, fused_gromov_wasserstein"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n# We create two 1D random measures\nn = 20 # number of points in the first distribution\nn2 = 30 # number of points in the second distribution\nsig = 1 # std of first distribution\nsig2 = 0.1 # std of second distribution\n\nnp.random.seed(0)\n\nphi = np.arange(n)[:, None]\nxs = phi + sig * np.random.randn(n, 1)\nys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)\n\nphi2 = np.arange(n2)[:, None]\nxt = phi2 + sig * np.random.randn(n2, 1)\nyt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)\nyt = yt[::-1, :]\n\np = ot.unif(n)\nq = ot.unif(n2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% plot the distributions\n\npl.close(10)\npl.figure(10, (7, 7))\n\npl.subplot(2, 1, 1)\n\npl.scatter(ys, xs, c=phi, s=70)\npl.ylabel('Feature value a', fontsize=20)\npl.title('$\\mu=\\sum_i \\delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)\npl.xticks(())\npl.yticks(())\npl.subplot(2, 1, 2)\npl.scatter(yt, xt, c=phi2, s=70)\npl.xlabel('coordinates x/y', fontsize=25)\npl.ylabel('Feature value b', fontsize=20)\npl.title('$\\\\nu=\\sum_j \\delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)\npl.yticks(())\npl.tight_layout()\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Create structure matrices and across-feature distance matrix\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Structure matrices and across-features distance matrix\nC1 = ot.dist(xs)\nC2 = ot.dist(xt)\nM = ot.dist(ys, yt)\nw1 = ot.unif(C1.shape[0])\nw2 = ot.unif(C2.shape[0])\nGot = ot.emd([], [], M)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot matrices\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%%\ncmap = 'Reds'\npl.close(10)\npl.figure(10, (5, 5))\nfs = 15\nl_x = [0, 5, 10, 15]\nl_y = [0, 5, 10, 15, 20, 25]\ngs = pl.GridSpec(5, 5)\n\nax1 = pl.subplot(gs[3:, :2])\n\npl.imshow(C1, cmap=cmap, interpolation='nearest')\npl.title(\"$C_1$\", fontsize=fs)\npl.xlabel(\"$k$\", fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\npl.xticks(l_x)\npl.yticks(l_x)\n\nax2 = pl.subplot(gs[:3, 2:])\n\npl.imshow(C2, cmap=cmap, interpolation='nearest')\npl.title(\"$C_2$\", fontsize=fs)\npl.ylabel(\"$l$\", fontsize=fs)\n#pl.ylabel(\"$l$\",fontsize=fs)\npl.xticks(())\npl.yticks(l_y)\nax2.set_aspect('auto')\n\nax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)\npl.imshow(M, cmap=cmap, interpolation='nearest')\npl.yticks(l_x)\npl.xticks(l_y)\npl.ylabel(\"$i$\", fontsize=fs)\npl.title(\"$M_{AB}$\", fontsize=fs)\npl.xlabel(\"$j$\", fontsize=fs)\npl.tight_layout()\nax3.set_aspect('auto')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute FGW/GW\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Computing FGW and GW\nalpha = 1e-3\n\not.tic()\nGwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)\not.toc()\n\n#%reload_ext WGW\nGg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Visualize transport matrices\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% visu OT matrix\ncmap = 'Blues'\nfs = 15\npl.figure(2, (13, 5))\npl.clf()\npl.subplot(1, 3, 1)\npl.imshow(Got, cmap=cmap, interpolation='nearest')\n#pl.xlabel(\"$y$\",fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\npl.xticks(())\n\npl.title('Wasserstein ($M$ only)')\n\npl.subplot(1, 3, 2)\npl.imshow(Gg, cmap=cmap, interpolation='nearest')\npl.title('Gromov ($C_1,C_2$ only)')\npl.xticks(())\npl.subplot(1, 3, 3)\npl.imshow(Gwg, cmap=cmap, interpolation='nearest')\npl.title('FGW ($M+C_1,C_2$)')\n\npl.xlabel(\"$j$\", fontsize=fs)\npl.ylabel(\"$i$\", fontsize=fs)\n\npl.tight_layout()\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_fgw.py b/docs/source/auto_examples/plot_fgw.py
new file mode 100644
index 0000000..43efc94
--- /dev/null
+++ b/docs/source/auto_examples/plot_fgw.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+"""
+==============================
+Plot Fused-gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures[18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+"""
+
+# Author: Titouan Vayer <titouan.vayer@irisa.fr>
+#
+# License: MIT License
+
+import matplotlib.pyplot as pl
+import numpy as np
+import ot
+from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
+##############################################################################
+# Generate data
+# ---------
+
+#%% parameters
+# We create two 1D random measures
+n = 20 # number of points in the first distribution
+n2 = 30 # number of points in the second distribution
+sig = 1 # std of first distribution
+sig2 = 0.1 # std of second distribution
+
+np.random.seed(0)
+
+phi = np.arange(n)[:, None]
+xs = phi + sig * np.random.randn(n, 1)
+ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)
+
+phi2 = np.arange(n2)[:, None]
+xt = phi2 + sig * np.random.randn(n2, 1)
+yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
+yt = yt[::-1, :]
+
+p = ot.unif(n)
+q = ot.unif(n2)
+
+##############################################################################
+# Plot data
+# ---------
+
+#%% plot the distributions
+
+pl.close(10)
+pl.figure(10, (7, 7))
+
+pl.subplot(2, 1, 1)
+
+pl.scatter(ys, xs, c=phi, s=70)
+pl.ylabel('Feature value a', fontsize=20)
+pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+pl.xticks(())
+pl.yticks(())
+pl.subplot(2, 1, 2)
+pl.scatter(yt, xt, c=phi2, s=70)
+pl.xlabel('coordinates x/y', fontsize=25)
+pl.ylabel('Feature value b', fontsize=20)
+pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+pl.yticks(())
+pl.tight_layout()
+pl.show()
+
+##############################################################################
+# Create structure matrices and across-feature distance matrix
+# ---------
+
+#%% Structure matrices and across-features distance matrix
+C1 = ot.dist(xs)
+C2 = ot.dist(xt)
+M = ot.dist(ys, yt)
+w1 = ot.unif(C1.shape[0])
+w2 = ot.unif(C2.shape[0])
+Got = ot.emd([], [], M)
+
+##############################################################################
+# Plot matrices
+# ---------
+
+#%%
+cmap = 'Reds'
+pl.close(10)
+pl.figure(10, (5, 5))
+fs = 15
+l_x = [0, 5, 10, 15]
+l_y = [0, 5, 10, 15, 20, 25]
+gs = pl.GridSpec(5, 5)
+
+ax1 = pl.subplot(gs[3:, :2])
+
+pl.imshow(C1, cmap=cmap, interpolation='nearest')
+pl.title("$C_1$", fontsize=fs)
+pl.xlabel("$k$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(l_x)
+pl.yticks(l_x)
+
+ax2 = pl.subplot(gs[:3, 2:])
+
+pl.imshow(C2, cmap=cmap, interpolation='nearest')
+pl.title("$C_2$", fontsize=fs)
+pl.ylabel("$l$", fontsize=fs)
+#pl.ylabel("$l$",fontsize=fs)
+pl.xticks(())
+pl.yticks(l_y)
+ax2.set_aspect('auto')
+
+ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
+pl.imshow(M, cmap=cmap, interpolation='nearest')
+pl.yticks(l_x)
+pl.xticks(l_y)
+pl.ylabel("$i$", fontsize=fs)
+pl.title("$M_{AB}$", fontsize=fs)
+pl.xlabel("$j$", fontsize=fs)
+pl.tight_layout()
+ax3.set_aspect('auto')
+pl.show()
+
+##############################################################################
+# Compute FGW/GW
+# ---------
+
+#%% Computing FGW and GW
+alpha = 1e-3
+
+ot.tic()
+Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
+ot.toc()
+
+#%reload_ext WGW
+Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+
+##############################################################################
+# Visualize transport matrices
+# ---------
+
+#%% visu OT matrix
+cmap = 'Blues'
+fs = 15
+pl.figure(2, (13, 5))
+pl.clf()
+pl.subplot(1, 3, 1)
+pl.imshow(Got, cmap=cmap, interpolation='nearest')
+#pl.xlabel("$y$",fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+pl.xticks(())
+
+pl.title('Wasserstein ($M$ only)')
+
+pl.subplot(1, 3, 2)
+pl.imshow(Gg, cmap=cmap, interpolation='nearest')
+pl.title('Gromov ($C_1,C_2$ only)')
+pl.xticks(())
+pl.subplot(1, 3, 3)
+pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
+pl.title('FGW ($M+C_1,C_2$)')
+
+pl.xlabel("$j$", fontsize=fs)
+pl.ylabel("$i$", fontsize=fs)
+
+pl.tight_layout()
+pl.show()
diff --git a/docs/source/auto_examples/plot_fgw.rst b/docs/source/auto_examples/plot_fgw.rst
new file mode 100644
index 0000000..aec725d
--- /dev/null
+++ b/docs/source/auto_examples/plot_fgw.rst
@@ -0,0 +1,297 @@
+
+
+.. _sphx_glr_auto_examples_plot_fgw.py:
+
+
+==============================
+Plot Fused-gromov-Wasserstein
+==============================
+
+This example illustrates the computation of FGW for 1D measures[18].
+
+.. [18] Vayer Titouan, Chapel Laetitia, Flamary R{'e}mi, Tavenard Romain
+ and Courty Nicolas
+ "Optimal Transport for structured data with application on graphs"
+ International Conference on Machine Learning (ICML). 2019.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Titouan Vayer <titouan.vayer@irisa.fr>
+ #
+ # License: MIT License
+
+ import matplotlib.pyplot as pl
+ import numpy as np
+ import ot
+ from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
+
+
+
+
+
+
+
+Generate data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+ # We create two 1D random measures
+ n = 20 # number of points in the first distribution
+ n2 = 30 # number of points in the second distribution
+ sig = 1 # std of first distribution
+ sig2 = 0.1 # std of second distribution
+
+ np.random.seed(0)
+
+ phi = np.arange(n)[:, None]
+ xs = phi + sig * np.random.randn(n, 1)
+ ys = np.vstack((np.ones((n // 2, 1)), 0 * np.ones((n // 2, 1)))) + sig2 * np.random.randn(n, 1)
+
+ phi2 = np.arange(n2)[:, None]
+ xt = phi2 + sig * np.random.randn(n2, 1)
+ yt = np.vstack((np.ones((n2 // 2, 1)), 0 * np.ones((n2 // 2, 1)))) + sig2 * np.random.randn(n2, 1)
+ yt = yt[::-1, :]
+
+ p = ot.unif(n)
+ q = ot.unif(n2)
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.close(10)
+ pl.figure(10, (7, 7))
+
+ pl.subplot(2, 1, 1)
+
+ pl.scatter(ys, xs, c=phi, s=70)
+ pl.ylabel('Feature value a', fontsize=20)
+ pl.title('$\mu=\sum_i \delta_{x_i,a_i}$', fontsize=25, usetex=True, y=1)
+ pl.xticks(())
+ pl.yticks(())
+ pl.subplot(2, 1, 2)
+ pl.scatter(yt, xt, c=phi2, s=70)
+ pl.xlabel('coordinates x/y', fontsize=25)
+ pl.ylabel('Feature value b', fontsize=20)
+ pl.title('$\\nu=\sum_j \delta_{y_j,b_j}$', fontsize=25, usetex=True, y=1)
+ pl.yticks(())
+ pl.tight_layout()
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_fgw_010.png
+ :align: center
+
+
+
+
+Create structure matrices and across-feature distance matrix
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% Structure matrices and across-features distance matrix
+ C1 = ot.dist(xs)
+ C2 = ot.dist(xt)
+ M = ot.dist(ys, yt)
+ w1 = ot.unif(C1.shape[0])
+ w2 = ot.unif(C2.shape[0])
+ Got = ot.emd([], [], M)
+
+
+
+
+
+
+
+Plot matrices
+---------
+
+
+
+.. code-block:: python
+
+
+ #%%
+ cmap = 'Reds'
+ pl.close(10)
+ pl.figure(10, (5, 5))
+ fs = 15
+ l_x = [0, 5, 10, 15]
+ l_y = [0, 5, 10, 15, 20, 25]
+ gs = pl.GridSpec(5, 5)
+
+ ax1 = pl.subplot(gs[3:, :2])
+
+ pl.imshow(C1, cmap=cmap, interpolation='nearest')
+ pl.title("$C_1$", fontsize=fs)
+ pl.xlabel("$k$", fontsize=fs)
+ pl.ylabel("$i$", fontsize=fs)
+ pl.xticks(l_x)
+ pl.yticks(l_x)
+
+ ax2 = pl.subplot(gs[:3, 2:])
+
+ pl.imshow(C2, cmap=cmap, interpolation='nearest')
+ pl.title("$C_2$", fontsize=fs)
+ pl.ylabel("$l$", fontsize=fs)
+ #pl.ylabel("$l$",fontsize=fs)
+ pl.xticks(())
+ pl.yticks(l_y)
+ ax2.set_aspect('auto')
+
+ ax3 = pl.subplot(gs[3:, 2:], sharex=ax2, sharey=ax1)
+ pl.imshow(M, cmap=cmap, interpolation='nearest')
+ pl.yticks(l_x)
+ pl.xticks(l_y)
+ pl.ylabel("$i$", fontsize=fs)
+ pl.title("$M_{AB}$", fontsize=fs)
+ pl.xlabel("$j$", fontsize=fs)
+ pl.tight_layout()
+ ax3.set_aspect('auto')
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_fgw_011.png
+ :align: center
+
+
+
+
+Compute FGW/GW
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% Computing FGW and GW
+ alpha = 1e-3
+
+ ot.tic()
+ Gwg, logw = fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=alpha, verbose=True, log=True)
+ ot.toc()
+
+ #%reload_ext WGW
+ Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Relative loss|Absolute loss
+ ------------------------------------------------
+ 0|4.734462e+01|0.000000e+00|0.000000e+00
+ 1|2.508258e+01|8.875498e-01|2.226204e+01
+ 2|2.189329e+01|1.456747e-01|3.189297e+00
+ 3|2.189329e+01|0.000000e+00|0.000000e+00
+ Elapsed time : 0.0016989707946777344 s
+ It. |Loss |Relative loss|Absolute loss
+ ------------------------------------------------
+ 0|4.683978e+04|0.000000e+00|0.000000e+00
+ 1|3.860061e+04|2.134468e-01|8.239175e+03
+ 2|2.182948e+04|7.682787e-01|1.677113e+04
+ 3|2.182948e+04|0.000000e+00|0.000000e+00
+
+
+Visualize transport matrices
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% visu OT matrix
+ cmap = 'Blues'
+ fs = 15
+ pl.figure(2, (13, 5))
+ pl.clf()
+ pl.subplot(1, 3, 1)
+ pl.imshow(Got, cmap=cmap, interpolation='nearest')
+ #pl.xlabel("$y$",fontsize=fs)
+ pl.ylabel("$i$", fontsize=fs)
+ pl.xticks(())
+
+ pl.title('Wasserstein ($M$ only)')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(Gg, cmap=cmap, interpolation='nearest')
+ pl.title('Gromov ($C_1,C_2$ only)')
+ pl.xticks(())
+ pl.subplot(1, 3, 3)
+ pl.imshow(Gwg, cmap=cmap, interpolation='nearest')
+ pl.title('FGW ($M+C_1,C_2$)')
+
+ pl.xlabel("$j$", fontsize=fs)
+ pl.ylabel("$i$", fontsize=fs)
+
+ pl.tight_layout()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_fgw_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 1.468 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_fgw.py <plot_fgw.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_fgw.ipynb <plot_fgw.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_free_support_barycenter.ipynb b/docs/source/auto_examples/plot_free_support_barycenter.ipynb
new file mode 100644
index 0000000..05a81c8
--- /dev/null
+++ b/docs/source/auto_examples/plot_free_support_barycenter.ipynb
@@ -0,0 +1,108 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# 2D free support Wasserstein barycenters of distributions\n\n\nIllustration of 2D Wasserstein barycenters if discributions that are weighted\nsum of diracs.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n -------------\n%% parameters and data generation\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "N = 3\nd = 2\nmeasures_locations = []\nmeasures_weights = []\n\nfor i in range(N):\n\n n_i = np.random.randint(low=1, high=20) # nb samples\n\n mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean\n\n A_i = np.random.rand(d, d)\n cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix\n\n x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations\n b_i = np.random.uniform(0., 1., (n_i,))\n b_i = b_i / np.sum(b_i) # Dirac weights\n\n measures_locations.append(x_i)\n measures_weights.append(b_i)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute free support barycenter\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "k = 10 # number of Diracs of the barycenter\nX_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations\nb = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)\n\nX = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1)\nfor (x_i, b_i) in zip(measures_locations, measures_weights):\n color = np.random.randint(low=1, high=10 * N)\n pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')\npl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')\npl.title('Data measures and their barycenter')\npl.legend(loc=0)\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_free_support_barycenter.py b/docs/source/auto_examples/plot_free_support_barycenter.py
new file mode 100644
index 0000000..b6efc59
--- /dev/null
+++ b/docs/source/auto_examples/plot_free_support_barycenter.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+2D free support Wasserstein barycenters of distributions
+====================================================
+
+Illustration of 2D Wasserstein barycenters if discributions that are weighted
+sum of diracs.
+
+"""
+
+# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+#%% parameters and data generation
+N = 3
+d = 2
+measures_locations = []
+measures_weights = []
+
+for i in range(N):
+
+ n_i = np.random.randint(low=1, high=20) # nb samples
+
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
+
+ measures_locations.append(x_i)
+ measures_weights.append(b_i)
+
+
+##############################################################################
+# Compute free support barycenter
+# -------------
+
+k = 10 # number of Diracs of the barycenter
+X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1)
+for (x_i, b_i) in zip(measures_locations, measures_weights):
+ color = np.random.randint(low=1, high=10 * N)
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')
+pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+pl.title('Data measures and their barycenter')
+pl.legend(loc=0)
+pl.show()
diff --git a/docs/source/auto_examples/plot_free_support_barycenter.rst b/docs/source/auto_examples/plot_free_support_barycenter.rst
new file mode 100644
index 0000000..d1b3b80
--- /dev/null
+++ b/docs/source/auto_examples/plot_free_support_barycenter.rst
@@ -0,0 +1,140 @@
+
+
+.. _sphx_glr_auto_examples_plot_free_support_barycenter.py:
+
+
+====================================================
+2D free support Wasserstein barycenters of distributions
+====================================================
+
+Illustration of 2D Wasserstein barycenters if discributions that are weighted
+sum of diracs.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+ -------------
+%% parameters and data generation
+
+
+
+.. code-block:: python
+
+ N = 3
+ d = 2
+ measures_locations = []
+ measures_weights = []
+
+ for i in range(N):
+
+ n_i = np.random.randint(low=1, high=20) # nb samples
+
+ mu_i = np.random.normal(0., 4., (d,)) # Gaussian mean
+
+ A_i = np.random.rand(d, d)
+ cov_i = np.dot(A_i, A_i.transpose()) # Gaussian covariance matrix
+
+ x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i) # Dirac locations
+ b_i = np.random.uniform(0., 1., (n_i,))
+ b_i = b_i / np.sum(b_i) # Dirac weights
+
+ measures_locations.append(x_i)
+ measures_weights.append(b_i)
+
+
+
+
+
+
+
+
+Compute free support barycenter
+-------------
+
+
+
+.. code-block:: python
+
+
+ k = 10 # number of Diracs of the barycenter
+ X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
+ b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)
+
+ X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)
+
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1)
+ for (x_i, b_i) in zip(measures_locations, measures_weights):
+ color = np.random.randint(low=1, high=10 * N)
+ pl.scatter(x_i[:, 0], x_i[:, 1], s=b * 1000, label='input measure')
+ pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
+ pl.title('Data measures and their barycenter')
+ pl.legend(loc=0)
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_free_support_barycenter_001.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.129 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_free_support_barycenter.py <plot_free_support_barycenter.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_free_support_barycenter.ipynb <plot_free_support_barycenter.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_gromov.ipynb b/docs/source/auto_examples/plot_gromov.ipynb
new file mode 100644
index 0000000..dc1f179
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Gromov-Wasserstein example\n\n\nThis example is designed to show how to use the Gromov-Wassertsein distance\ncomputation in POT.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Erwan Vautier <erwan.vautier@gmail.com>\n# Nicolas Courty <ncourty@irisa.fr>\n#\n# License: MIT License\n\nimport scipy as sp\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Sample two Gaussian distributions (2D and 3D)\n---------------------------------------------\n\nThe Gromov-Wasserstein distance allows to compute distances with samples that\ndo not belong to the same metric space. For demonstration purpose, we sample\ntwo Gaussian distributions in 2- and 3-dimensional spaces.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_samples = 30 # nb samples\n\nmu_s = np.array([0, 0])\ncov_s = np.array([[1, 0], [0, 1]])\n\nmu_t = np.array([4, 4, 4])\ncov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n\n\nxs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)\nP = sp.linalg.sqrtm(cov_t)\nxt = np.random.randn(n_samples, 3).dot(P) + mu_t"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plotting the distributions\n--------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "fig = pl.figure()\nax1 = fig.add_subplot(121)\nax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\nax2 = fig.add_subplot(122, projection='3d')\nax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute distance kernels, normalize them and then display\n---------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "C1 = sp.spatial.distance.cdist(xs, xs)\nC2 = sp.spatial.distance.cdist(xt, xt)\n\nC1 /= C1.max()\nC2 /= C2.max()\n\npl.figure()\npl.subplot(121)\npl.imshow(C1)\npl.subplot(122)\npl.imshow(C2)\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compute Gromov-Wasserstein plans and distance\n---------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "p = ot.unif(n_samples)\nq = ot.unif(n_samples)\n\ngw0, log0 = ot.gromov.gromov_wasserstein(\n C1, C2, p, q, 'square_loss', verbose=True, log=True)\n\ngw, log = ot.gromov.entropic_gromov_wasserstein(\n C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)\n\n\nprint('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))\nprint('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))\n\n\npl.figure(1, (10, 5))\n\npl.subplot(1, 2, 1)\npl.imshow(gw0, cmap='jet')\npl.title('Gromov Wasserstein')\n\npl.subplot(1, 2, 2)\npl.imshow(gw, cmap='jet')\npl.title('Entropic Gromov Wasserstein')\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_gromov.py b/docs/source/auto_examples/plot_gromov.py
new file mode 100644
index 0000000..deb2f86
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov.py
@@ -0,0 +1,106 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D # noqa
+import ot
+
+#############################################################################
+#
+# Sample two Gaussian distributions (2D and 3D)
+# ---------------------------------------------
+#
+# The Gromov-Wasserstein distance allows to compute distances with samples that
+# do not belong to the same metric space. For demonstration purpose, we sample
+# two Gaussian distributions in 2- and 3-dimensional spaces.
+
+
+n_samples = 30 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+#############################################################################
+#
+# Plotting the distributions
+# --------------------------
+
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+#############################################################################
+#
+# Compute distance kernels, normalize them and then display
+# ---------------------------------------------------------
+
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+C1 /= C1.max()
+C2 /= C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+#############################################################################
+#
+# Compute Gromov-Wasserstein plans and distance
+# ---------------------------------------------
+
+p = ot.unif(n_samples)
+q = ot.unif(n_samples)
+
+gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
+
+gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
+
+
+print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(gw0, cmap='jet')
+pl.title('Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(gw, cmap='jet')
+pl.title('Entropic Gromov Wasserstein')
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_gromov.rst b/docs/source/auto_examples/plot_gromov.rst
new file mode 100644
index 0000000..3ed4e11
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov.rst
@@ -0,0 +1,214 @@
+
+
+.. _sphx_glr_auto_examples_plot_gromov.py:
+
+
+==========================
+Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the Gromov-Wassertsein distance
+computation in POT.
+
+
+
+.. code-block:: python
+
+
+ # Author: Erwan Vautier <erwan.vautier@gmail.com>
+ # Nicolas Courty <ncourty@irisa.fr>
+ #
+ # License: MIT License
+
+ import scipy as sp
+ import numpy as np
+ import matplotlib.pylab as pl
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ import ot
+
+
+
+
+
+
+
+Sample two Gaussian distributions (2D and 3D)
+---------------------------------------------
+
+The Gromov-Wasserstein distance allows to compute distances with samples that
+do not belong to the same metric space. For demonstration purpose, we sample
+two Gaussian distributions in 2- and 3-dimensional spaces.
+
+
+
+.. code-block:: python
+
+
+
+ n_samples = 30 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([4, 4, 4])
+ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
+ P = sp.linalg.sqrtm(cov_t)
+ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+
+
+
+
+
+
+Plotting the distributions
+--------------------------
+
+
+
+.. code-block:: python
+
+
+
+ fig = pl.figure()
+ ax1 = fig.add_subplot(121)
+ ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ ax2 = fig.add_subplot(122, projection='3d')
+ ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_001.png
+ :align: center
+
+
+
+
+Compute distance kernels, normalize them and then display
+---------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+
+ C1 = sp.spatial.distance.cdist(xs, xs)
+ C2 = sp.spatial.distance.cdist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ pl.figure()
+ pl.subplot(121)
+ pl.imshow(C1)
+ pl.subplot(122)
+ pl.imshow(C2)
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_002.png
+ :align: center
+
+
+
+
+Compute Gromov-Wasserstein plans and distance
+---------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ gw0, log0 = ot.gromov.gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', verbose=True, log=True)
+
+ gw, log = ot.gromov.entropic_gromov_wasserstein(
+ C1, C2, p, q, 'square_loss', epsilon=5e-4, log=True, verbose=True)
+
+
+ print('Gromov-Wasserstein distances: ' + str(log0['gw_dist']))
+ print('Entropic Gromov-Wasserstein distances: ' + str(log['gw_dist']))
+
+
+ pl.figure(1, (10, 5))
+
+ pl.subplot(1, 2, 1)
+ pl.imshow(gw0, cmap='jet')
+ pl.title('Gromov Wasserstein')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(gw, cmap='jet')
+ pl.title('Entropic Gromov Wasserstein')
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_003.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|4.328711e-02|0.000000e+00
+ 1|2.281369e-02|-8.974178e-01
+ 2|1.843659e-02|-2.374139e-01
+ 3|1.602820e-02|-1.502598e-01
+ 4|1.353712e-02|-1.840179e-01
+ 5|1.285687e-02|-5.290977e-02
+ 6|1.284537e-02|-8.952931e-04
+ 7|1.284525e-02|-8.989584e-06
+ 8|1.284525e-02|-8.989950e-08
+ 9|1.284525e-02|-8.989949e-10
+ It. |Err
+ -------------------
+ 0|7.263293e-02|
+ 10|1.737784e-02|
+ 20|7.783978e-03|
+ 30|3.399419e-07|
+ 40|3.751207e-11|
+ Gromov-Wasserstein distances: 0.012845252089244688
+ Entropic Gromov-Wasserstein distances: 0.013543882352191079
+
+
+**Total running time of the script:** ( 0 minutes 1.916 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_gromov.py <plot_gromov.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_gromov.ipynb <plot_gromov.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.ipynb b/docs/source/auto_examples/plot_gromov_barycenter.ipynb
new file mode 100644
index 0000000..4c2f28f
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ],
+ "cell_type": "code"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "\n# Gromov-Wasserstein Barycenter example\n\n\nThis example is designed to show how to use the Gromov-Wasserstein distance\ncomputation in POT.\n\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Erwan Vautier <erwan.vautier@gmail.com>\n# Nicolas Courty <ncourty@irisa.fr>\n#\n# License: MIT License\n\n\nimport numpy as np\nimport scipy as sp\n\nimport scipy.ndimage as spi\nimport matplotlib.pylab as pl\nfrom sklearn import manifold\nfrom sklearn.decomposition import PCA\n\nimport ot"
+ ],
+ "cell_type": "code"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Smacof MDS\n----------\n\nThis function allows to find an embedding of points given a dissimilarity matrix\nthat will be given by the output of the algorithm\n\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def smacof_mds(C, dim, max_iter=3000, eps=1e-9):\n \"\"\"\n Returns an interpolated point cloud following the dissimilarity matrix C\n using SMACOF multidimensional scaling (MDS) in specific dimensionned\n target space\n\n Parameters\n ----------\n C : ndarray, shape (ns, ns)\n dissimilarity matrix\n dim : int\n dimension of the targeted space\n max_iter : int\n Maximum number of iterations of the SMACOF algorithm for a single run\n eps : float\n relative tolerance w.r.t stress to declare converge\n\n Returns\n -------\n npos : ndarray, shape (R, dim)\n Embedded coordinates of the interpolated point cloud (defined with\n one isometry)\n \"\"\"\n\n rng = np.random.RandomState(seed=3)\n\n mds = manifold.MDS(\n dim,\n max_iter=max_iter,\n eps=1e-9,\n dissimilarity='precomputed',\n n_init=1)\n pos = mds.fit(C).embedding_\n\n nmds = manifold.MDS(\n 2,\n max_iter=max_iter,\n eps=1e-9,\n dissimilarity=\"precomputed\",\n random_state=rng,\n n_init=1)\n npos = nmds.fit_transform(C, init=pos)\n\n return npos"
+ ],
+ "cell_type": "code"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Data preparation\n----------------\n\nThe four distributions are constructed from 4 simple images\n\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\nsquare = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256\ncross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256\ntriangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256\nstar = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256\n\nshapes = [square, cross, triangle, star]\n\nS = 4\nxs = [[] for i in range(S)]\n\n\nfor nb in range(4):\n for i in range(8):\n for j in range(8):\n if shapes[nb][i, j] < 0.95:\n xs[nb].append([j, 8 - i])\n\nxs = np.array([np.array(xs[0]), np.array(xs[1]),\n np.array(xs[2]), np.array(xs[3])])"
+ ],
+ "cell_type": "code"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "ns = [len(xs[s]) for s in range(S)]\nn_samples = 30\n\n\"\"\"Compute all distances matrices for the four shapes\"\"\"\nCs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]\nCs = [cs / cs.max() for cs in Cs]\n\nps = [ot.unif(ns[s]) for s in range(S)]\np = ot.unif(n_samples)\n\n\nlambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]\n\nCt01 = [0 for i in range(2)]\nfor i in range(2):\n Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],\n [ps[0], ps[1]\n ], p, lambdast[i], 'square_loss', # 5e-4,\n max_iter=100, tol=1e-3)\n\nCt02 = [0 for i in range(2)]\nfor i in range(2):\n Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],\n [ps[0], ps[2]\n ], p, lambdast[i], 'square_loss', # 5e-4,\n max_iter=100, tol=1e-3)\n\nCt13 = [0 for i in range(2)]\nfor i in range(2):\n Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],\n [ps[1], ps[3]\n ], p, lambdast[i], 'square_loss', # 5e-4,\n max_iter=100, tol=1e-3)\n\nCt23 = [0 for i in range(2)]\nfor i in range(2):\n Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],\n [ps[2], ps[3]\n ], p, lambdast[i], 'square_loss', # 5e-4,\n max_iter=100, tol=1e-3)"
+ ],
+ "cell_type": "code"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Visualization\n-------------\n\nThe PCA helps in getting consistency between the rotations\n\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "clf = PCA(n_components=2)\nnpos = [0, 0, 0, 0]\nnpos = [smacof_mds(Cs[s], 2) for s in range(S)]\n\nnpost01 = [0, 0]\nnpost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]\nnpost01 = [clf.fit_transform(npost01[s]) for s in range(2)]\n\nnpost02 = [0, 0]\nnpost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]\nnpost02 = [clf.fit_transform(npost02[s]) for s in range(2)]\n\nnpost13 = [0, 0]\nnpost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]\nnpost13 = [clf.fit_transform(npost13[s]) for s in range(2)]\n\nnpost23 = [0, 0]\nnpost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]\nnpost23 = [clf.fit_transform(npost23[s]) for s in range(2)]\n\n\nfig = pl.figure(figsize=(10, 10))\n\nax1 = pl.subplot2grid((4, 4), (0, 0))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')\n\nax2 = pl.subplot2grid((4, 4), (0, 1))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')\n\nax3 = pl.subplot2grid((4, 4), (0, 2))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')\n\nax4 = pl.subplot2grid((4, 4), (0, 3))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')\n\nax5 = pl.subplot2grid((4, 4), (1, 0))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')\n\nax6 = pl.subplot2grid((4, 4), (1, 3))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')\n\nax7 = pl.subplot2grid((4, 4), (2, 0))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')\n\nax8 = pl.subplot2grid((4, 4), (2, 3))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')\n\nax9 = pl.subplot2grid((4, 4), (3, 0))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')\n\nax10 = pl.subplot2grid((4, 4), (3, 1))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')\n\nax11 = pl.subplot2grid((4, 4), (3, 2))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')\n\nax12 = pl.subplot2grid((4, 4), (3, 3))\npl.xlim((-1, 1))\npl.ylim((-1, 1))\nax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')"
+ ],
+ "cell_type": "code"
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python",
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "nbconvert_exporter": "python",
+ "version": "3.5.2",
+ "pygments_lexer": "ipython3",
+ "file_extension": ".py",
+ "mimetype": "text/x-python"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3",
+ "language": "python"
+ }
+ },
+ "nbformat_minor": 0,
+ "nbformat": 4
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.py b/docs/source/auto_examples/plot_gromov_barycenter.py
new file mode 100644
index 0000000..58fc51a
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+
+import scipy.ndimage as spi
+import matplotlib.pylab as pl
+from sklearn import manifold
+from sklearn.decomposition import PCA
+
+import ot
+
+##############################################################################
+# Smacof MDS
+# ----------
+#
+# This function allows to find an embedding of points given a dissimilarity matrix
+# that will be given by the output of the algorithm
+
+
+def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+ Returns an interpolated point cloud following the dissimilarity matrix C
+ using SMACOF multidimensional scaling (MDS) in specific dimensionned
+ target space
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+ relative tolerance w.r.t stress to declare converge
+
+ Returns
+ -------
+ npos : ndarray, shape (R, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ 2,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity="precomputed",
+ random_state=rng,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+shapes = [square, cross, triangle, star]
+
+S = 4
+xs = [[] for i in range(S)]
+
+
+for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+
+ns = [len(xs[s]) for s in range(S)]
+n_samples = 30
+
+"""Compute all distances matrices for the four shapes"""
+Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+Cs = [cs / cs.max() for cs in Cs]
+
+ps = [ot.unif(ns[s]) for s in range(S)]
+p = ot.unif(n_samples)
+
+
+lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+Ct01 = [0 for i in range(2)]
+for i in range(2):
+ Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
+ [ps[0], ps[1]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct02 = [0 for i in range(2)]
+for i in range(2):
+ Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
+ [ps[0], ps[2]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct13 = [0 for i in range(2)]
+for i in range(2):
+ Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
+ [ps[1], ps[3]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+Ct23 = [0 for i in range(2)]
+for i in range(2):
+ Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
+ [ps[2], ps[3]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+
+##############################################################################
+# Visualization
+# -------------
+#
+# The PCA helps in getting consistency between the rotations
+
+
+clf = PCA(n_components=2)
+npos = [0, 0, 0, 0]
+npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+npost01 = [0, 0]
+npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+npost02 = [0, 0]
+npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+npost13 = [0, 0]
+npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+npost23 = [0, 0]
+npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+fig = pl.figure(figsize=(10, 10))
+
+ax1 = pl.subplot2grid((4, 4), (0, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ax2 = pl.subplot2grid((4, 4), (0, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ax3 = pl.subplot2grid((4, 4), (0, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ax4 = pl.subplot2grid((4, 4), (0, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ax5 = pl.subplot2grid((4, 4), (1, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ax6 = pl.subplot2grid((4, 4), (1, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ax7 = pl.subplot2grid((4, 4), (2, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ax8 = pl.subplot2grid((4, 4), (2, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ax9 = pl.subplot2grid((4, 4), (3, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ax10 = pl.subplot2grid((4, 4), (3, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ax11 = pl.subplot2grid((4, 4), (3, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ax12 = pl.subplot2grid((4, 4), (3, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.rst b/docs/source/auto_examples/plot_gromov_barycenter.rst
new file mode 100644
index 0000000..531ee22
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.rst
@@ -0,0 +1,329 @@
+
+
+.. _sphx_glr_auto_examples_plot_gromov_barycenter.py:
+
+
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+
+
+
+.. code-block:: python
+
+
+ # Author: Erwan Vautier <erwan.vautier@gmail.com>
+ # Nicolas Courty <ncourty@irisa.fr>
+ #
+ # License: MIT License
+
+
+ import numpy as np
+ import scipy as sp
+
+ import scipy.ndimage as spi
+ import matplotlib.pylab as pl
+ from sklearn import manifold
+ from sklearn.decomposition import PCA
+
+ import ot
+
+
+
+
+
+
+
+Smacof MDS
+----------
+
+This function allows to find an embedding of points given a dissimilarity matrix
+that will be given by the output of the algorithm
+
+
+
+.. code-block:: python
+
+
+
+ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+ Returns an interpolated point cloud following the dissimilarity matrix C
+ using SMACOF multidimensional scaling (MDS) in specific dimensionned
+ target space
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+ relative tolerance w.r.t stress to declare converge
+
+ Returns
+ -------
+ npos : ndarray, shape (R, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ 2,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity="precomputed",
+ random_state=rng,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+
+
+
+
+
+
+Data preparation
+----------------
+
+The four distributions are constructed from 4 simple images
+
+
+
+.. code-block:: python
+
+
+
+ def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+ cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+ triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+ star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+ shapes = [square, cross, triangle, star]
+
+ S = 4
+ xs = [[] for i in range(S)]
+
+
+ for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+ xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+
+
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+
+ ns = [len(xs[s]) for s in range(S)]
+ n_samples = 30
+
+ """Compute all distances matrices for the four shapes"""
+ Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+ Cs = [cs / cs.max() for cs in Cs]
+
+ ps = [ot.unif(ns[s]) for s in range(S)]
+ p = ot.unif(n_samples)
+
+
+ lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+ Ct01 = [0 for i in range(2)]
+ for i in range(2):
+ Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],
+ [ps[0], ps[1]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+ Ct02 = [0 for i in range(2)]
+ for i in range(2):
+ Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],
+ [ps[0], ps[2]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+ Ct13 = [0 for i in range(2)]
+ for i in range(2):
+ Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],
+ [ps[1], ps[3]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+ Ct23 = [0 for i in range(2)]
+ for i in range(2):
+ Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],
+ [ps[2], ps[3]
+ ], p, lambdast[i], 'square_loss', # 5e-4,
+ max_iter=100, tol=1e-3)
+
+
+
+
+
+
+
+
+Visualization
+-------------
+
+The PCA helps in getting consistency between the rotations
+
+
+
+.. code-block:: python
+
+
+
+ clf = PCA(n_components=2)
+ npos = [0, 0, 0, 0]
+ npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+ npost01 = [0, 0]
+ npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+ npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+ npost02 = [0, 0]
+ npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+ npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+ npost13 = [0, 0]
+ npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+ npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+ npost23 = [0, 0]
+ npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+ npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+ fig = pl.figure(figsize=(10, 10))
+
+ ax1 = pl.subplot2grid((4, 4), (0, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ ax2 = pl.subplot2grid((4, 4), (0, 1))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ ax3 = pl.subplot2grid((4, 4), (0, 2))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ ax4 = pl.subplot2grid((4, 4), (0, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ ax5 = pl.subplot2grid((4, 4), (1, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ ax6 = pl.subplot2grid((4, 4), (1, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ ax7 = pl.subplot2grid((4, 4), (2, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ ax8 = pl.subplot2grid((4, 4), (2, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ ax9 = pl.subplot2grid((4, 4), (3, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ ax10 = pl.subplot2grid((4, 4), (3, 1))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ ax11 = pl.subplot2grid((4, 4), (3, 2))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ ax12 = pl.subplot2grid((4, 4), (3, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 5.906 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_gromov_barycenter.py <plot_gromov_barycenter.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_gromov_barycenter.ipynb <plot_gromov_barycenter.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_optim_OTreg.ipynb b/docs/source/auto_examples/plot_optim_OTreg.ipynb
new file mode 100644
index 0000000..107c299
--- /dev/null
+++ b/docs/source/auto_examples/plot_optim_OTreg.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Regularized OT with generic solver\n\n\nIllustrates the use of the generic solver for regularized OT with\nuser-designed regularization term. It uses Conditional gradient as in [6] and\ngeneralized Conditional Gradient as proposed in [5][7].\n\n\n[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for\nDomain Adaptation, in IEEE Transactions on Pattern Analysis and Machine\nIntelligence , vol.PP, no.99, pp.1-1.\n\n[6] Ferradans, S., Papadakis, N., Peyr\u00e9, G., & Aujol, J. F. (2014).\nRegularized discrete optimal transport. SIAM Journal on Imaging Sciences,\n7(3), 1853-1882.\n\n[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized\nconditional gradient: analysis of convergence and applications.\narXiv preprint arXiv:1510.06567.\n\n\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std\nb = ot.datasets.make_1D_gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve EMD\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve EMD with Frobenius norm regularization\n--------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Example with Frobenius norm regularization\n\n\ndef f(G):\n return 0.5 * np.sum(G**2)\n\n\ndef df(G):\n return G\n\n\nreg = 1e-1\n\nGl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)\n\npl.figure(3)\not.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve EMD with entropic regularization\n--------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Example with entropic regularization\n\n\ndef f(G):\n return np.sum(G * np.log(G))\n\n\ndef df(G):\n return np.log(G) + 1.\n\n\nreg = 1e-3\n\nGe = ot.optim.cg(a, b, M, reg, f, df, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Solve EMD with Frobenius norm + entropic regularization\n-------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "#%% Example with Frobenius norm + entropic regularization with gcg\n\n\ndef f(G):\n return 0.5 * np.sum(G**2)\n\n\ndef df(G):\n return G\n\n\nreg1 = 1e-3\nreg2 = 1e-1\n\nGel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)\n\npl.figure(5, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_optim_OTreg.py b/docs/source/auto_examples/plot_optim_OTreg.py
new file mode 100644
index 0000000..2c58def
--- /dev/null
+++ b/docs/source/auto_examples/plot_optim_OTreg.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+"""
+==================================
+Regularized OT with generic solver
+==================================
+
+Illustrates the use of the generic solver for regularized OT with
+user-designed regularization term. It uses Conditional gradient as in [6] and
+generalized Conditional Gradient as proposed in [5][7].
+
+
+[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for
+Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
+Intelligence , vol.PP, no.99, pp.1-1.
+
+[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
+7(3), 1853-1882.
+
+[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
+conditional gradient: analysis of convergence and applications.
+arXiv preprint arXiv:1510.06567.
+
+
+
+"""
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data
+# -------------
+
+#%% parameters
+
+n = 100 # nb bins
+
+# bin positions
+x = np.arange(n, dtype=np.float64)
+
+# Gaussian distributions
+a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+b = ot.datasets.make_1D_gauss(n, m=60, s=10)
+
+# loss matrix
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Solve EMD
+# ---------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve EMD with Frobenius norm regularization
+# --------------------------------------------
+
+#%% Example with Frobenius norm regularization
+
+
+def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+def df(G):
+ return G
+
+
+reg = 1e-1
+
+Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+pl.figure(3)
+ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
+
+##############################################################################
+# Solve EMD with entropic regularization
+# --------------------------------------
+
+#%% Example with entropic regularization
+
+
+def f(G):
+ return np.sum(G * np.log(G))
+
+
+def df(G):
+ return np.log(G) + 1.
+
+
+reg = 1e-3
+
+Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
+
+##############################################################################
+# Solve EMD with Frobenius norm + entropic regularization
+# -------------------------------------------------------
+
+#%% Example with Frobenius norm + entropic regularization with gcg
+
+
+def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+def df(G):
+ return G
+
+
+reg1 = 1e-3
+reg2 = 1e-1
+
+Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
+
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
+pl.show()
diff --git a/docs/source/auto_examples/plot_optim_OTreg.rst b/docs/source/auto_examples/plot_optim_OTreg.rst
new file mode 100644
index 0000000..844cba0
--- /dev/null
+++ b/docs/source/auto_examples/plot_optim_OTreg.rst
@@ -0,0 +1,663 @@
+
+
+.. _sphx_glr_auto_examples_plot_optim_OTreg.py:
+
+
+==================================
+Regularized OT with generic solver
+==================================
+
+Illustrates the use of the generic solver for regularized OT with
+user-designed regularization term. It uses Conditional gradient as in [6] and
+generalized Conditional Gradient as proposed in [5][7].
+
+
+[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for
+Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
+Intelligence , vol.PP, no.99, pp.1-1.
+
+[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
+7(3), 1853-1882.
+
+[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
+conditional gradient: analysis of convergence and applications.
+arXiv preprint arXiv:1510.06567.
+
+
+
+
+
+
+.. code-block:: python
+
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b = ot.datasets.make_1D_gauss(n, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+Solve EMD
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% EMD
+
+ G0 = ot.emd(a, b, M)
+
+ pl.figure(3, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
+ :align: center
+
+
+
+
+Solve EMD with Frobenius norm regularization
+--------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Example with Frobenius norm regularization
+
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+ def df(G):
+ return G
+
+
+ reg = 1e-1
+
+ Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+ pl.figure(3)
+ ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|1.760578e-01|0.000000e+00
+ 1|1.669467e-01|-5.457501e-02
+ 2|1.665639e-01|-2.298130e-03
+ 3|1.664378e-01|-7.572776e-04
+ 4|1.664077e-01|-1.811855e-04
+ 5|1.663912e-01|-9.936787e-05
+ 6|1.663852e-01|-3.555826e-05
+ 7|1.663814e-01|-2.305693e-05
+ 8|1.663785e-01|-1.760450e-05
+ 9|1.663767e-01|-1.078011e-05
+ 10|1.663751e-01|-9.525192e-06
+ 11|1.663737e-01|-8.396466e-06
+ 12|1.663727e-01|-6.086938e-06
+ 13|1.663720e-01|-4.042609e-06
+ 14|1.663713e-01|-4.160914e-06
+ 15|1.663707e-01|-3.823502e-06
+ 16|1.663702e-01|-3.022440e-06
+ 17|1.663697e-01|-3.181249e-06
+ 18|1.663692e-01|-2.698532e-06
+ 19|1.663687e-01|-3.258253e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 20|1.663682e-01|-2.741118e-06
+ 21|1.663678e-01|-2.624135e-06
+ 22|1.663673e-01|-2.645179e-06
+ 23|1.663670e-01|-1.957237e-06
+ 24|1.663666e-01|-2.261541e-06
+ 25|1.663663e-01|-1.851305e-06
+ 26|1.663660e-01|-1.942296e-06
+ 27|1.663657e-01|-2.092896e-06
+ 28|1.663653e-01|-1.924361e-06
+ 29|1.663651e-01|-1.625455e-06
+ 30|1.663648e-01|-1.641123e-06
+ 31|1.663645e-01|-1.566666e-06
+ 32|1.663643e-01|-1.338514e-06
+ 33|1.663641e-01|-1.222711e-06
+ 34|1.663639e-01|-1.221805e-06
+ 35|1.663637e-01|-1.440781e-06
+ 36|1.663634e-01|-1.520091e-06
+ 37|1.663632e-01|-1.288193e-06
+ 38|1.663630e-01|-1.123055e-06
+ 39|1.663628e-01|-1.024487e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 40|1.663627e-01|-1.079606e-06
+ 41|1.663625e-01|-1.172093e-06
+ 42|1.663623e-01|-1.047880e-06
+ 43|1.663621e-01|-1.010577e-06
+ 44|1.663619e-01|-1.064438e-06
+ 45|1.663618e-01|-9.882375e-07
+ 46|1.663616e-01|-8.532647e-07
+ 47|1.663615e-01|-9.930189e-07
+ 48|1.663613e-01|-8.728955e-07
+ 49|1.663612e-01|-9.524214e-07
+ 50|1.663610e-01|-9.088418e-07
+ 51|1.663609e-01|-7.639430e-07
+ 52|1.663608e-01|-6.662611e-07
+ 53|1.663607e-01|-7.133700e-07
+ 54|1.663605e-01|-7.648141e-07
+ 55|1.663604e-01|-6.557516e-07
+ 56|1.663603e-01|-7.304213e-07
+ 57|1.663602e-01|-6.353809e-07
+ 58|1.663601e-01|-7.968279e-07
+ 59|1.663600e-01|-6.367159e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 60|1.663599e-01|-5.610790e-07
+ 61|1.663598e-01|-5.787466e-07
+ 62|1.663596e-01|-6.937777e-07
+ 63|1.663596e-01|-5.599432e-07
+ 64|1.663595e-01|-5.813048e-07
+ 65|1.663594e-01|-5.724600e-07
+ 66|1.663593e-01|-6.081892e-07
+ 67|1.663592e-01|-5.948732e-07
+ 68|1.663591e-01|-4.941833e-07
+ 69|1.663590e-01|-5.213739e-07
+ 70|1.663589e-01|-5.127355e-07
+ 71|1.663588e-01|-4.349251e-07
+ 72|1.663588e-01|-5.007084e-07
+ 73|1.663587e-01|-4.880265e-07
+ 74|1.663586e-01|-4.931950e-07
+ 75|1.663585e-01|-4.981309e-07
+ 76|1.663584e-01|-3.952959e-07
+ 77|1.663584e-01|-4.544857e-07
+ 78|1.663583e-01|-4.237579e-07
+ 79|1.663582e-01|-4.382386e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 80|1.663582e-01|-3.646051e-07
+ 81|1.663581e-01|-4.197994e-07
+ 82|1.663580e-01|-4.072764e-07
+ 83|1.663580e-01|-3.994645e-07
+ 84|1.663579e-01|-4.842721e-07
+ 85|1.663578e-01|-3.276486e-07
+ 86|1.663578e-01|-3.737346e-07
+ 87|1.663577e-01|-4.282043e-07
+ 88|1.663576e-01|-4.020937e-07
+ 89|1.663576e-01|-3.431951e-07
+ 90|1.663575e-01|-3.052335e-07
+ 91|1.663575e-01|-3.500538e-07
+ 92|1.663574e-01|-3.063176e-07
+ 93|1.663573e-01|-3.576367e-07
+ 94|1.663573e-01|-3.224681e-07
+ 95|1.663572e-01|-3.673221e-07
+ 96|1.663572e-01|-3.635561e-07
+ 97|1.663571e-01|-3.527236e-07
+ 98|1.663571e-01|-2.788548e-07
+ 99|1.663570e-01|-2.727141e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 100|1.663570e-01|-3.127278e-07
+ 101|1.663569e-01|-2.637504e-07
+ 102|1.663569e-01|-2.922750e-07
+ 103|1.663568e-01|-3.076454e-07
+ 104|1.663568e-01|-2.911509e-07
+ 105|1.663567e-01|-2.403398e-07
+ 106|1.663567e-01|-2.439790e-07
+ 107|1.663567e-01|-2.634542e-07
+ 108|1.663566e-01|-2.452203e-07
+ 109|1.663566e-01|-2.852991e-07
+ 110|1.663565e-01|-2.165490e-07
+ 111|1.663565e-01|-2.450250e-07
+ 112|1.663564e-01|-2.685294e-07
+ 113|1.663564e-01|-2.821800e-07
+ 114|1.663564e-01|-2.237390e-07
+ 115|1.663563e-01|-1.992842e-07
+ 116|1.663563e-01|-2.166739e-07
+ 117|1.663563e-01|-2.086064e-07
+ 118|1.663562e-01|-2.435945e-07
+ 119|1.663562e-01|-2.292497e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 120|1.663561e-01|-2.366209e-07
+ 121|1.663561e-01|-2.138746e-07
+ 122|1.663561e-01|-2.009637e-07
+ 123|1.663560e-01|-2.386258e-07
+ 124|1.663560e-01|-1.927442e-07
+ 125|1.663560e-01|-2.081681e-07
+ 126|1.663559e-01|-1.759123e-07
+ 127|1.663559e-01|-1.890771e-07
+ 128|1.663559e-01|-1.971315e-07
+ 129|1.663558e-01|-2.101983e-07
+ 130|1.663558e-01|-2.035645e-07
+ 131|1.663558e-01|-1.984492e-07
+ 132|1.663557e-01|-1.849064e-07
+ 133|1.663557e-01|-1.795703e-07
+ 134|1.663557e-01|-1.624087e-07
+ 135|1.663557e-01|-1.689557e-07
+ 136|1.663556e-01|-1.644308e-07
+ 137|1.663556e-01|-1.618007e-07
+ 138|1.663556e-01|-1.483013e-07
+ 139|1.663555e-01|-1.708771e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 140|1.663555e-01|-2.013847e-07
+ 141|1.663555e-01|-1.721217e-07
+ 142|1.663554e-01|-2.027911e-07
+ 143|1.663554e-01|-1.764565e-07
+ 144|1.663554e-01|-1.677151e-07
+ 145|1.663554e-01|-1.351982e-07
+ 146|1.663553e-01|-1.423360e-07
+ 147|1.663553e-01|-1.541112e-07
+ 148|1.663553e-01|-1.491601e-07
+ 149|1.663553e-01|-1.466407e-07
+ 150|1.663552e-01|-1.801524e-07
+ 151|1.663552e-01|-1.714107e-07
+ 152|1.663552e-01|-1.491257e-07
+ 153|1.663552e-01|-1.513799e-07
+ 154|1.663551e-01|-1.354539e-07
+ 155|1.663551e-01|-1.233818e-07
+ 156|1.663551e-01|-1.576219e-07
+ 157|1.663551e-01|-1.452791e-07
+ 158|1.663550e-01|-1.262867e-07
+ 159|1.663550e-01|-1.316379e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 160|1.663550e-01|-1.295447e-07
+ 161|1.663550e-01|-1.283286e-07
+ 162|1.663550e-01|-1.569222e-07
+ 163|1.663549e-01|-1.172942e-07
+ 164|1.663549e-01|-1.399809e-07
+ 165|1.663549e-01|-1.229432e-07
+ 166|1.663549e-01|-1.326191e-07
+ 167|1.663548e-01|-1.209694e-07
+ 168|1.663548e-01|-1.372136e-07
+ 169|1.663548e-01|-1.338395e-07
+ 170|1.663548e-01|-1.416497e-07
+ 171|1.663548e-01|-1.298576e-07
+ 172|1.663547e-01|-1.190590e-07
+ 173|1.663547e-01|-1.167083e-07
+ 174|1.663547e-01|-1.069425e-07
+ 175|1.663547e-01|-1.217780e-07
+ 176|1.663547e-01|-1.140754e-07
+ 177|1.663546e-01|-1.160707e-07
+ 178|1.663546e-01|-1.101798e-07
+ 179|1.663546e-01|-1.114904e-07
+ It. |Loss |Delta loss
+ --------------------------------
+ 180|1.663546e-01|-1.064022e-07
+ 181|1.663546e-01|-9.258231e-08
+ 182|1.663546e-01|-1.213120e-07
+ 183|1.663545e-01|-1.164296e-07
+ 184|1.663545e-01|-1.188762e-07
+ 185|1.663545e-01|-9.394153e-08
+ 186|1.663545e-01|-1.028656e-07
+ 187|1.663545e-01|-1.115348e-07
+ 188|1.663544e-01|-9.768310e-08
+ 189|1.663544e-01|-1.021806e-07
+ 190|1.663544e-01|-1.086303e-07
+ 191|1.663544e-01|-9.879008e-08
+ 192|1.663544e-01|-1.050210e-07
+ 193|1.663544e-01|-1.002463e-07
+ 194|1.663543e-01|-1.062747e-07
+ 195|1.663543e-01|-9.348538e-08
+ 196|1.663543e-01|-7.992512e-08
+ 197|1.663543e-01|-9.558020e-08
+ 198|1.663543e-01|-9.993772e-08
+ 199|1.663543e-01|-8.588499e-08
+ It. |Loss |Delta loss
+ --------------------------------
+ 200|1.663543e-01|-8.737134e-08
+
+
+Solve EMD with entropic regularization
+--------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Example with entropic regularization
+
+
+ def f(G):
+ return np.sum(G * np.log(G))
+
+
+ def df(G):
+ return np.log(G) + 1.
+
+
+ reg = 1e-3
+
+ Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_006.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|1.692289e-01|0.000000e+00
+ 1|1.617643e-01|-4.614437e-02
+ 2|1.612639e-01|-3.102965e-03
+ 3|1.611291e-01|-8.371098e-04
+ 4|1.610468e-01|-5.110558e-04
+ 5|1.610198e-01|-1.672927e-04
+ 6|1.610130e-01|-4.232417e-05
+ 7|1.610090e-01|-2.513455e-05
+ 8|1.610002e-01|-5.443507e-05
+ 9|1.609996e-01|-3.657071e-06
+ 10|1.609948e-01|-2.998735e-05
+ 11|1.609695e-01|-1.569217e-04
+ 12|1.609533e-01|-1.010779e-04
+ 13|1.609520e-01|-8.043897e-06
+ 14|1.609465e-01|-3.415246e-05
+ 15|1.609386e-01|-4.898605e-05
+ 16|1.609324e-01|-3.837052e-05
+ 17|1.609298e-01|-1.617826e-05
+ 18|1.609184e-01|-7.080015e-05
+ 19|1.609083e-01|-6.273206e-05
+ It. |Loss |Delta loss
+ --------------------------------
+ 20|1.608988e-01|-5.940805e-05
+ 21|1.608853e-01|-8.380030e-05
+ 22|1.608844e-01|-5.185045e-06
+ 23|1.608824e-01|-1.279113e-05
+ 24|1.608819e-01|-3.156821e-06
+ 25|1.608783e-01|-2.205746e-05
+ 26|1.608764e-01|-1.189894e-05
+ 27|1.608755e-01|-5.474607e-06
+ 28|1.608737e-01|-1.144227e-05
+ 29|1.608676e-01|-3.775335e-05
+ 30|1.608638e-01|-2.348020e-05
+ 31|1.608627e-01|-6.863136e-06
+ 32|1.608529e-01|-6.110230e-05
+ 33|1.608487e-01|-2.641106e-05
+ 34|1.608409e-01|-4.823638e-05
+ 35|1.608373e-01|-2.256641e-05
+ 36|1.608338e-01|-2.132444e-05
+ 37|1.608310e-01|-1.786649e-05
+ 38|1.608260e-01|-3.103848e-05
+ 39|1.608206e-01|-3.321265e-05
+ It. |Loss |Delta loss
+ --------------------------------
+ 40|1.608201e-01|-3.054747e-06
+ 41|1.608195e-01|-4.198335e-06
+ 42|1.608193e-01|-8.458736e-07
+ 43|1.608159e-01|-2.153759e-05
+ 44|1.608115e-01|-2.738314e-05
+ 45|1.608108e-01|-3.960032e-06
+ 46|1.608081e-01|-1.675447e-05
+ 47|1.608072e-01|-5.976340e-06
+ 48|1.608046e-01|-1.604130e-05
+ 49|1.608020e-01|-1.617036e-05
+ 50|1.608014e-01|-3.957795e-06
+ 51|1.608011e-01|-1.292411e-06
+ 52|1.607998e-01|-8.431795e-06
+ 53|1.607964e-01|-2.127054e-05
+ 54|1.607947e-01|-1.021878e-05
+ 55|1.607947e-01|-3.560621e-07
+ 56|1.607900e-01|-2.929781e-05
+ 57|1.607890e-01|-5.740229e-06
+ 58|1.607858e-01|-2.039550e-05
+ 59|1.607836e-01|-1.319545e-05
+ It. |Loss |Delta loss
+ --------------------------------
+ 60|1.607826e-01|-6.378947e-06
+ 61|1.607808e-01|-1.145102e-05
+ 62|1.607776e-01|-1.941743e-05
+ 63|1.607743e-01|-2.087422e-05
+ 64|1.607741e-01|-1.310249e-06
+ 65|1.607738e-01|-1.682752e-06
+ 66|1.607691e-01|-2.913936e-05
+ 67|1.607671e-01|-1.288855e-05
+ 68|1.607654e-01|-1.002448e-05
+ 69|1.607641e-01|-8.209492e-06
+ 70|1.607632e-01|-5.588467e-06
+ 71|1.607619e-01|-8.050388e-06
+ 72|1.607618e-01|-9.417493e-07
+ 73|1.607598e-01|-1.210509e-05
+ 74|1.607591e-01|-4.392914e-06
+ 75|1.607579e-01|-7.759587e-06
+ 76|1.607574e-01|-2.760280e-06
+ 77|1.607556e-01|-1.146469e-05
+ 78|1.607550e-01|-3.689456e-06
+ 79|1.607550e-01|-4.065631e-08
+ It. |Loss |Delta loss
+ --------------------------------
+ 80|1.607539e-01|-6.555681e-06
+ 81|1.607528e-01|-7.177470e-06
+ 82|1.607527e-01|-5.306068e-07
+ 83|1.607514e-01|-7.816045e-06
+ 84|1.607511e-01|-2.301970e-06
+ 85|1.607504e-01|-4.281072e-06
+ 86|1.607503e-01|-7.821886e-07
+ 87|1.607480e-01|-1.403013e-05
+ 88|1.607480e-01|-1.169298e-08
+ 89|1.607473e-01|-4.235982e-06
+ 90|1.607470e-01|-1.717105e-06
+ 91|1.607470e-01|-6.148402e-09
+ 92|1.607462e-01|-5.396481e-06
+ 93|1.607461e-01|-5.194954e-07
+ 94|1.607450e-01|-6.525707e-06
+ 95|1.607442e-01|-5.332060e-06
+ 96|1.607439e-01|-1.682093e-06
+ 97|1.607437e-01|-1.594796e-06
+ 98|1.607435e-01|-7.923812e-07
+ 99|1.607420e-01|-9.738552e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 100|1.607419e-01|-1.022448e-07
+ 101|1.607419e-01|-4.865999e-07
+ 102|1.607418e-01|-7.092012e-07
+ 103|1.607408e-01|-5.861815e-06
+ 104|1.607402e-01|-3.953266e-06
+ 105|1.607395e-01|-3.969572e-06
+ 106|1.607390e-01|-3.612075e-06
+ 107|1.607377e-01|-7.683735e-06
+ 108|1.607365e-01|-7.777599e-06
+ 109|1.607364e-01|-2.335096e-07
+ 110|1.607364e-01|-4.562036e-07
+ 111|1.607360e-01|-2.089538e-06
+ 112|1.607356e-01|-2.755355e-06
+ 113|1.607349e-01|-4.501960e-06
+ 114|1.607347e-01|-1.160544e-06
+ 115|1.607346e-01|-6.289450e-07
+ 116|1.607345e-01|-2.092146e-07
+ 117|1.607336e-01|-5.990866e-06
+ 118|1.607330e-01|-3.348498e-06
+ 119|1.607328e-01|-1.256222e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 120|1.607320e-01|-5.418353e-06
+ 121|1.607318e-01|-8.296189e-07
+ 122|1.607311e-01|-4.381608e-06
+ 123|1.607310e-01|-8.913901e-07
+ 124|1.607309e-01|-3.808821e-07
+ 125|1.607302e-01|-4.608994e-06
+ 126|1.607294e-01|-5.063777e-06
+ 127|1.607290e-01|-2.532835e-06
+ 128|1.607285e-01|-2.870049e-06
+ 129|1.607284e-01|-4.892812e-07
+ 130|1.607281e-01|-1.760452e-06
+ 131|1.607279e-01|-1.727139e-06
+ 132|1.607275e-01|-2.220706e-06
+ 133|1.607271e-01|-2.516930e-06
+ 134|1.607269e-01|-1.201434e-06
+ 135|1.607269e-01|-2.183459e-09
+ 136|1.607262e-01|-4.223011e-06
+ 137|1.607258e-01|-2.530202e-06
+ 138|1.607258e-01|-1.857260e-07
+ 139|1.607256e-01|-1.401957e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 140|1.607250e-01|-3.242751e-06
+ 141|1.607247e-01|-2.308071e-06
+ 142|1.607247e-01|-4.730700e-08
+ 143|1.607246e-01|-4.240229e-07
+ 144|1.607242e-01|-2.484810e-06
+ 145|1.607238e-01|-2.539206e-06
+ 146|1.607234e-01|-2.535574e-06
+ 147|1.607231e-01|-1.954802e-06
+ 148|1.607228e-01|-1.765447e-06
+ 149|1.607228e-01|-1.620007e-08
+ 150|1.607222e-01|-3.615783e-06
+ 151|1.607222e-01|-8.668516e-08
+ 152|1.607215e-01|-4.000673e-06
+ 153|1.607213e-01|-1.774103e-06
+ 154|1.607213e-01|-6.328834e-09
+ 155|1.607209e-01|-2.418783e-06
+ 156|1.607208e-01|-2.848492e-07
+ 157|1.607207e-01|-8.836043e-07
+ 158|1.607205e-01|-1.192836e-06
+ 159|1.607202e-01|-1.638022e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 160|1.607202e-01|-3.670914e-08
+ 161|1.607197e-01|-3.153709e-06
+ 162|1.607197e-01|-2.419565e-09
+ 163|1.607194e-01|-2.136882e-06
+ 164|1.607194e-01|-1.173754e-09
+ 165|1.607192e-01|-8.169238e-07
+ 166|1.607191e-01|-9.218755e-07
+ 167|1.607189e-01|-9.459255e-07
+ 168|1.607187e-01|-1.294835e-06
+ 169|1.607186e-01|-5.797668e-07
+ 170|1.607186e-01|-4.706272e-08
+ 171|1.607183e-01|-1.753383e-06
+ 172|1.607183e-01|-1.681573e-07
+ 173|1.607183e-01|-2.563971e-10
+
+
+Solve EMD with Frobenius norm + entropic regularization
+-------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Example with Frobenius norm + entropic regularization with gcg
+
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+ def df(G):
+ return G
+
+
+ reg1 = 1e-3
+ reg2 = 1e-1
+
+ Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
+
+ pl.figure(5, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_008.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|1.693084e-01|0.000000e+00
+ 1|1.610121e-01|-5.152589e-02
+ 2|1.609378e-01|-4.622297e-04
+ 3|1.609284e-01|-5.830043e-05
+ 4|1.609284e-01|-1.111407e-12
+
+
+**Total running time of the script:** ( 0 minutes 1.990 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_optim_OTreg.py <plot_optim_OTreg.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_optim_OTreg.ipynb <plot_optim_OTreg.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_classes.ipynb b/docs/source/auto_examples/plot_otda_classes.ipynb
new file mode 100644
index 0000000..643e760
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_classes.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OT for domain adaptation\n\n\nThis example introduces a domain adaptation in a 2D setting and the 4 OTDA\napproaches currently supported in POT.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_source_samples = 150\nn_target_samples = 150\n\nXs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)\nXt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# EMD Transport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport with Group lasso regularization\not_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)\not_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)\n\n# Sinkhorn Transport with Group lasso regularization l1l2\not_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,\n verbose=True)\not_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)\n\n# transport source samples onto target samples\ntransp_Xs_emd = ot_emd.transform(Xs=Xs)\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)\ntransp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)\ntransp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 1 : plots source and target samples\n---------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, figsize=(10, 5))\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Source samples')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Target samples')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 2 : plot optimal couplings and transported samples\n------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "param_img = {'interpolation': 'nearest'}\n\npl.figure(2, figsize=(15, 8))\npl.subplot(2, 4, 1)\npl.imshow(ot_emd.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nEMDTransport')\n\npl.subplot(2, 4, 2)\npl.imshow(ot_sinkhorn.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornTransport')\n\npl.subplot(2, 4, 3)\npl.imshow(ot_lpl1.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornLpl1Transport')\n\npl.subplot(2, 4, 4)\npl.imshow(ot_l1l2.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornL1l2Transport')\n\npl.subplot(2, 4, 5)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nEmdTransport')\npl.legend(loc=\"lower left\")\n\npl.subplot(2, 4, 6)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nSinkhornTransport')\n\npl.subplot(2, 4, 7)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nSinkhornLpl1Transport')\n\npl.subplot(2, 4, 8)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nSinkhornL1l2Transport')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_classes.py b/docs/source/auto_examples/plot_otda_classes.py
new file mode 100644
index 0000000..c311fbd
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_classes.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+"""
+========================
+OT for domain adaptation
+========================
+
+This example introduces a domain adaptation in a 2D setting and the 4 OTDA
+approaches currently supported in POT.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 150
+n_target_samples = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization
+ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization l1l2
+ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
+ verbose=True)
+ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
+
+
+##############################################################################
+# Fig 1 : plots source and target samples
+# ---------------------------------------
+
+pl.figure(1, figsize=(10, 5))
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plot optimal couplings and transported samples
+# ------------------------------------------------------
+
+param_img = {'interpolation': 'nearest'}
+
+pl.figure(2, figsize=(15, 8))
+pl.subplot(2, 4, 1)
+pl.imshow(ot_emd.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.subplot(2, 4, 2)
+pl.imshow(ot_sinkhorn.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 4, 3)
+pl.imshow(ot_lpl1.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+pl.subplot(2, 4, 4)
+pl.imshow(ot_l1l2.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornL1l2Transport')
+
+pl.subplot(2, 4, 5)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc="lower left")
+
+pl.subplot(2, 4, 6)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornTransport')
+
+pl.subplot(2, 4, 7)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornLpl1Transport')
+
+pl.subplot(2, 4, 8)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornL1l2Transport')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_classes.rst b/docs/source/auto_examples/plot_otda_classes.rst
new file mode 100644
index 0000000..19756ff
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_classes.rst
@@ -0,0 +1,263 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_classes.py:
+
+
+========================
+OT for domain adaptation
+========================
+
+This example introduces a domain adaptation in a 2D setting and the 4 OTDA
+approaches currently supported in POT.
+
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_source_samples = 150
+ n_target_samples = 150
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples)
+
+
+
+
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMD Transport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport with Group lasso regularization
+ ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ # Sinkhorn Transport with Group lasso regularization l1l2
+ ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
+ verbose=True)
+ ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ # transport source samples onto target samples
+ transp_Xs_emd = ot_emd.transform(Xs=Xs)
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+ transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+ transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|9.566309e+00|0.000000e+00
+ 1|2.169680e+00|-3.409088e+00
+ 2|1.914989e+00|-1.329986e-01
+ 3|1.860251e+00|-2.942498e-02
+ 4|1.838073e+00|-1.206621e-02
+ 5|1.827064e+00|-6.025122e-03
+ 6|1.820899e+00|-3.386082e-03
+ 7|1.817290e+00|-1.985705e-03
+ 8|1.814644e+00|-1.458223e-03
+ 9|1.812661e+00|-1.093816e-03
+ 10|1.810239e+00|-1.338121e-03
+ 11|1.809100e+00|-6.296940e-04
+ 12|1.807939e+00|-6.420646e-04
+ 13|1.806965e+00|-5.389118e-04
+ 14|1.806822e+00|-7.889599e-05
+ 15|1.806193e+00|-3.482356e-04
+ 16|1.805735e+00|-2.536930e-04
+ 17|1.805321e+00|-2.292667e-04
+ 18|1.804389e+00|-5.170222e-04
+ 19|1.803908e+00|-2.661907e-04
+ It. |Loss |Delta loss
+ --------------------------------
+ 20|1.803696e+00|-1.178279e-04
+
+
+Fig 1 : plots source and target samples
+---------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(10, 5))
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Source samples')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Target samples')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_classes_001.png
+ :align: center
+
+
+
+
+Fig 2 : plot optimal couplings and transported samples
+------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ param_img = {'interpolation': 'nearest'}
+
+ pl.figure(2, figsize=(15, 8))
+ pl.subplot(2, 4, 1)
+ pl.imshow(ot_emd.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nEMDTransport')
+
+ pl.subplot(2, 4, 2)
+ pl.imshow(ot_sinkhorn.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornTransport')
+
+ pl.subplot(2, 4, 3)
+ pl.imshow(ot_lpl1.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+ pl.subplot(2, 4, 4)
+ pl.imshow(ot_l1l2.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornL1l2Transport')
+
+ pl.subplot(2, 4, 5)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nEmdTransport')
+ pl.legend(loc="lower left")
+
+ pl.subplot(2, 4, 6)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nSinkhornTransport')
+
+ pl.subplot(2, 4, 7)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nSinkhornLpl1Transport')
+
+ pl.subplot(2, 4, 8)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nSinkhornL1l2Transport')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_classes_003.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 1.423 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_classes.py <plot_otda_classes.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_classes.ipynb <plot_otda_classes.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_color_images.ipynb b/docs/source/auto_examples/plot_otda_color_images.ipynb
new file mode 100644
index 0000000..103bdec
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_color_images.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OT for image color adaptation\n\n\nThis example presents a way of transferring colors between two images\nwith Optimal Transport as introduced in [6]\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).\nRegularized discrete optimal transport.\nSIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nfrom scipy import ndimage\nimport matplotlib.pylab as pl\nimport ot\n\n\nr = np.random.RandomState(42)\n\n\ndef im2mat(I):\n \"\"\"Converts an image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef mat2im(X, shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\n\ndef minmax(I):\n return np.clip(I, 0, 1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Loading images\nI1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256\nI2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)\n\n# training samples\nnb = 1000\nidx1 = r.randint(X1.shape[0], size=(nb,))\nidx2 = r.randint(X2.shape[0], size=(nb,))\n\nXs = X1[idx1, :]\nXt = X2[idx2, :]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot original image\n-------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, figsize=(6.4, 3))\n\npl.subplot(1, 2, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Scatter plot of colors\n----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(6.4, 3))\n\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# EMDTransport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# SinkhornTransport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# prediction between images (using out of sample prediction as in [6])\ntransp_Xs_emd = ot_emd.transform(Xs=X1)\ntransp_Xt_emd = ot_emd.inverse_transform(Xt=X2)\n\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)\ntransp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)\n\nI1t = minmax(mat2im(transp_Xs_emd, I1.shape))\nI2t = minmax(mat2im(transp_Xt_emd, I2.shape))\n\nI1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))\nI2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot new images\n---------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(3, figsize=(8, 4))\n\npl.subplot(2, 3, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(2, 3, 2)\npl.imshow(I1t)\npl.axis('off')\npl.title('Image 1 Adapt')\n\npl.subplot(2, 3, 3)\npl.imshow(I1te)\npl.axis('off')\npl.title('Image 1 Adapt (reg)')\n\npl.subplot(2, 3, 4)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')\n\npl.subplot(2, 3, 5)\npl.imshow(I2t)\npl.axis('off')\npl.title('Image 2 Adapt')\n\npl.subplot(2, 3, 6)\npl.imshow(I2te)\npl.axis('off')\npl.title('Image 2 Adapt (reg)')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_color_images.py b/docs/source/auto_examples/plot_otda_color_images.py
new file mode 100644
index 0000000..62383a2
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_color_images.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+"""
+=============================
+OT for image color adaptation
+=============================
+
+This example presents a way of transferring colors between two images
+with Optimal Transport as introduced in [6]
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport.
+SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+from scipy import ndimage
+import matplotlib.pylab as pl
+import ot
+
+
+r = np.random.RandomState(42)
+
+
+def im2mat(I):
+ """Converts an image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# Loading images
+I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+# training samples
+nb = 1000
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
+
+Xs = X1[idx1, :]
+Xt = X2[idx2, :]
+
+
+##############################################################################
+# Plot original image
+# -------------------
+
+pl.figure(1, figsize=(6.4, 3))
+
+pl.subplot(1, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+
+
+##############################################################################
+# Scatter plot of colors
+# ----------------------
+
+pl.figure(2, figsize=(6.4, 3))
+
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMDTransport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# SinkhornTransport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# prediction between images (using out of sample prediction as in [6])
+transp_Xs_emd = ot_emd.transform(Xs=X1)
+transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
+
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+transp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)
+
+I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
+I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
+
+I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
+
+
+##############################################################################
+# Plot new images
+# ---------------
+
+pl.figure(3, figsize=(8, 4))
+
+pl.subplot(2, 3, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(2, 3, 2)
+pl.imshow(I1t)
+pl.axis('off')
+pl.title('Image 1 Adapt')
+
+pl.subplot(2, 3, 3)
+pl.imshow(I1te)
+pl.axis('off')
+pl.title('Image 1 Adapt (reg)')
+
+pl.subplot(2, 3, 4)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+
+pl.subplot(2, 3, 5)
+pl.imshow(I2t)
+pl.axis('off')
+pl.title('Image 2 Adapt')
+
+pl.subplot(2, 3, 6)
+pl.imshow(I2te)
+pl.axis('off')
+pl.title('Image 2 Adapt (reg)')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_color_images.rst b/docs/source/auto_examples/plot_otda_color_images.rst
new file mode 100644
index 0000000..ab0406e
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_color_images.rst
@@ -0,0 +1,262 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_color_images.py:
+
+
+=============================
+OT for image color adaptation
+=============================
+
+This example presents a way of transferring colors between two images
+with Optimal Transport as introduced in [6]
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport.
+SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import numpy as np
+ from scipy import ndimage
+ import matplotlib.pylab as pl
+ import ot
+
+
+ r = np.random.RandomState(42)
+
+
+ def im2mat(I):
+ """Converts an image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+ def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ # Loading images
+ I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+ I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+ X1 = im2mat(I1)
+ X2 = im2mat(I2)
+
+ # training samples
+ nb = 1000
+ idx1 = r.randint(X1.shape[0], size=(nb,))
+ idx2 = r.randint(X2.shape[0], size=(nb,))
+
+ Xs = X1[idx1, :]
+ Xt = X2[idx2, :]
+
+
+
+
+
+
+
+
+Plot original image
+-------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(6.4, 3))
+
+ pl.subplot(1, 2, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Image 2')
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_color_images_001.png
+ :align: center
+
+
+
+
+Scatter plot of colors
+----------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(6.4, 3))
+
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 2')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_color_images_003.png
+ :align: center
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMDTransport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ # SinkhornTransport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+ # prediction between images (using out of sample prediction as in [6])
+ transp_Xs_emd = ot_emd.transform(Xs=X1)
+ transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
+
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+ transp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)
+
+ I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
+ I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
+
+ I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
+
+
+
+
+
+
+
+
+Plot new images
+---------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(3, figsize=(8, 4))
+
+ pl.subplot(2, 3, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Image 1')
+
+ pl.subplot(2, 3, 2)
+ pl.imshow(I1t)
+ pl.axis('off')
+ pl.title('Image 1 Adapt')
+
+ pl.subplot(2, 3, 3)
+ pl.imshow(I1te)
+ pl.axis('off')
+ pl.title('Image 1 Adapt (reg)')
+
+ pl.subplot(2, 3, 4)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Image 2')
+
+ pl.subplot(2, 3, 5)
+ pl.imshow(I2t)
+ pl.axis('off')
+ pl.title('Image 2 Adapt')
+
+ pl.subplot(2, 3, 6)
+ pl.imshow(I2te)
+ pl.axis('off')
+ pl.title('Image 2 Adapt (reg)')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_color_images_005.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 3 minutes 55.541 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_color_images.py <plot_otda_color_images.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_color_images.ipynb <plot_otda_color_images.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_d2.ipynb b/docs/source/auto_examples/plot_otda_d2.ipynb
new file mode 100644
index 0000000..b9002ee
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_d2.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OT for domain adaptation on empirical distributions\n\n\nThis example introduces a domain adaptation in a 2D setting. It explicits\nthe problem of domain adaptation and introduces some optimal transport\napproaches to solve it.\n\nQuantities such as optimal couplings, greater coupling coefficients and\ntransported samples are represented in order to give a visual understanding\nof what the transport methods are doing.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_samples_source = 150\nn_samples_target = 150\n\nXs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)\nXt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)\n\n# Cost matrix\nM = ot.dist(Xs, Xt, metric='sqeuclidean')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# EMD Transport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport with Group lasso regularization\not_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)\not_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)\n\n# transport source samples onto target samples\ntransp_Xs_emd = ot_emd.transform(Xs=Xs)\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)\ntransp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 1 : plots source and target samples + matrix of pairwise distance\n---------------------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, figsize=(10, 10))\npl.subplot(2, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Source samples')\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Target samples')\n\npl.subplot(2, 2, 3)\npl.imshow(M, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Matrix of pairwise distances')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 2 : plots optimal couplings for the different methods\n---------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(10, 6))\n\npl.subplot(2, 3, 1)\npl.imshow(ot_emd.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nEMDTransport')\n\npl.subplot(2, 3, 2)\npl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornTransport')\n\npl.subplot(2, 3, 3)\npl.imshow(ot_lpl1.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornLpl1Transport')\n\npl.subplot(2, 3, 4)\not.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.title('Main coupling coefficients\\nEMDTransport')\n\npl.subplot(2, 3, 5)\not.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.title('Main coupling coefficients\\nSinkhornTransport')\n\npl.subplot(2, 3, 6)\not.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.title('Main coupling coefficients\\nSinkhornLpl1Transport')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 3 : plot transported samples\n--------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# display transported samples\npl.figure(4, figsize=(10, 4))\npl.subplot(1, 3, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nEmdTransport')\npl.legend(loc=0)\npl.xticks([])\npl.yticks([])\n\npl.subplot(1, 3, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nSinkhornTransport')\npl.xticks([])\npl.yticks([])\n\npl.subplot(1, 3, 3)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nSinkhornLpl1Transport')\npl.xticks([])\npl.yticks([])\n\npl.tight_layout()\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_d2.py b/docs/source/auto_examples/plot_otda_d2.py
new file mode 100644
index 0000000..cf22c2f
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_d2.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""
+===================================================
+OT for domain adaptation on empirical distributions
+===================================================
+
+This example introduces a domain adaptation in a 2D setting. It explicits
+the problem of domain adaptation and introduces some optimal transport
+approaches to solve it.
+
+Quantities such as optimal couplings, greater coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# generate data
+# -------------
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)
+
+# Cost matrix
+M = ot.dist(Xs, Xt, metric='sqeuclidean')
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization
+ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+# ---------------------------------------------------------------------
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(M, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Matrix of pairwise distances')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+# ---------------------------------------------------------
+pl.figure(2, figsize=(10, 6))
+
+pl.subplot(2, 3, 1)
+pl.imshow(ot_emd.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.subplot(2, 3, 2)
+pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(ot_lpl1.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+pl.subplot(2, 3, 4)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nEMDTransport')
+
+pl.subplot(2, 3, 5)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nSinkhornTransport')
+
+pl.subplot(2, 3, 6)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+# --------------------------------
+
+# display transported samples
+pl.figure(4, figsize=(10, 4))
+pl.subplot(1, 3, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 3, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornTransport')
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 3, 3)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornLpl1Transport')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_d2.rst b/docs/source/auto_examples/plot_otda_d2.rst
new file mode 100644
index 0000000..80cc34c
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_d2.rst
@@ -0,0 +1,269 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_d2.py:
+
+
+===================================================
+OT for domain adaptation on empirical distributions
+===================================================
+
+This example introduces a domain adaptation in a 2D setting. It explicits
+the problem of domain adaptation and introduces some optimal transport
+approaches to solve it.
+
+Quantities such as optimal couplings, greater coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import ot
+ import ot.plot
+
+
+
+
+
+
+
+generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_samples_source = 150
+ n_samples_target = 150
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)
+
+ # Cost matrix
+ M = ot.dist(Xs, Xt, metric='sqeuclidean')
+
+
+
+
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMD Transport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport with Group lasso regularization
+ ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ # transport source samples onto target samples
+ transp_Xs_emd = ot_emd.transform(Xs=Xs)
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+ transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+
+
+
+
+
+
+
+
+Fig 1 : plots source and target samples + matrix of pairwise distance
+---------------------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(10, 10))
+ pl.subplot(2, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Source samples')
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Target samples')
+
+ pl.subplot(2, 2, 3)
+ pl.imshow(M, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Matrix of pairwise distances')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_d2_001.png
+ :align: center
+
+
+
+
+Fig 2 : plots optimal couplings for the different methods
+---------------------------------------------------------
+
+
+
+.. code-block:: python
+
+ pl.figure(2, figsize=(10, 6))
+
+ pl.subplot(2, 3, 1)
+ pl.imshow(ot_emd.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nEMDTransport')
+
+ pl.subplot(2, 3, 2)
+ pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornTransport')
+
+ pl.subplot(2, 3, 3)
+ pl.imshow(ot_lpl1.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+ pl.subplot(2, 3, 4)
+ ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Main coupling coefficients\nEMDTransport')
+
+ pl.subplot(2, 3, 5)
+ ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Main coupling coefficients\nSinkhornTransport')
+
+ pl.subplot(2, 3, 6)
+ ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_d2_003.png
+ :align: center
+
+
+
+
+Fig 3 : plot transported samples
+--------------------------------
+
+
+
+.. code-block:: python
+
+
+ # display transported samples
+ pl.figure(4, figsize=(10, 4))
+ pl.subplot(1, 3, 1)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nEmdTransport')
+ pl.legend(loc=0)
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.subplot(1, 3, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nSinkhornTransport')
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.subplot(1, 3, 3)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nSinkhornLpl1Transport')
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.tight_layout()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_d2_006.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 35.515 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_d2.py <plot_otda_d2.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_d2.ipynb <plot_otda_d2.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_linear_mapping.ipynb b/docs/source/auto_examples/plot_otda_linear_mapping.ipynb
new file mode 100644
index 0000000..027b6cb
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_linear_mapping.ipynb
@@ -0,0 +1,180 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Linear OT mapping estimation\n\n\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport pylab as pl\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n = 1000\nd = 2\nsigma = .1\n\n# source samples\nangles = np.random.rand(n, 1) * 2 * np.pi\nxs = np.concatenate((np.sin(angles), np.cos(angles)),\n axis=1) + sigma * np.random.randn(n, 2)\nxs[:n // 2, 1] += 2\n\n\n# target samples\nanglet = np.random.rand(n, 1) * 2 * np.pi\nxt = np.concatenate((np.sin(anglet), np.cos(anglet)),\n axis=1) + sigma * np.random.randn(n, 2)\nxt[:n // 2, 1] += 2\n\n\nA = np.array([[1.5, .7], [.7, 1.5]])\nb = np.array([[4, 2]])\nxt = xt.dot(A) + b"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, (5, 5))\npl.plot(xs[:, 0], xs[:, 1], '+')\npl.plot(xt[:, 0], xt[:, 1], 'o')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Estimate linear mapping and transport\n-------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "Ae, be = ot.da.OT_mapping_linear(xs, xt)\n\nxst = xs.dot(Ae) + be"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot transported samples\n------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, (5, 5))\npl.clf()\npl.plot(xs[:, 0], xs[:, 1], '+')\npl.plot(xt[:, 0], xt[:, 1], 'o')\npl.plot(xst[:, 0], xst[:, 1], '+')\n\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Load image data\n---------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "def im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef mat2im(X, shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\n\ndef minmax(I):\n return np.clip(I, 0, 1)\n\n\n# Loading images\nI1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256\nI2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256\n\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Estimate mapping and adapt\n----------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "mapping = ot.da.LinearTransport()\n\nmapping.fit(Xs=X1, Xt=X2)\n\n\nxst = mapping.transform(Xs=X1)\nxts = mapping.inverse_transform(Xt=X2)\n\nI1t = minmax(mat2im(xst, I1.shape))\nI2t = minmax(mat2im(xts, I2.shape))\n\n# %%"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot transformed images\n-----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(10, 7))\n\npl.subplot(2, 2, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Im. 1')\n\npl.subplot(2, 2, 2)\npl.imshow(I2)\npl.axis('off')\npl.title('Im. 2')\n\npl.subplot(2, 2, 3)\npl.imshow(I1t)\npl.axis('off')\npl.title('Mapping Im. 1')\n\npl.subplot(2, 2, 4)\npl.imshow(I2t)\npl.axis('off')\npl.title('Inverse mapping Im. 2')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_linear_mapping.py b/docs/source/auto_examples/plot_otda_linear_mapping.py
new file mode 100644
index 0000000..c65bd4f
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_linear_mapping.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+============================
+Linear OT mapping estimation
+============================
+
+
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import numpy as np
+import pylab as pl
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+
+n = 1000
+d = 2
+sigma = .1
+
+# source samples
+angles = np.random.rand(n, 1) * 2 * np.pi
+xs = np.concatenate((np.sin(angles), np.cos(angles)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xs[:n // 2, 1] += 2
+
+
+# target samples
+anglet = np.random.rand(n, 1) * 2 * np.pi
+xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
+ axis=1) + sigma * np.random.randn(n, 2)
+xt[:n // 2, 1] += 2
+
+
+A = np.array([[1.5, .7], [.7, 1.5]])
+b = np.array([[4, 2]])
+xt = xt.dot(A) + b
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1, (5, 5))
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+
+
+##############################################################################
+# Estimate linear mapping and transport
+# -------------------------------------
+
+Ae, be = ot.da.OT_mapping_linear(xs, xt)
+
+xst = xs.dot(Ae) + be
+
+
+##############################################################################
+# Plot transported samples
+# ------------------------
+
+pl.figure(1, (5, 5))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+')
+pl.plot(xt[:, 0], xt[:, 1], 'o')
+pl.plot(xst[:, 0], xst[:, 1], '+')
+
+pl.show()
+
+##############################################################################
+# Load image data
+# ---------------
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+# Loading images
+I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+##############################################################################
+# Estimate mapping and adapt
+# ----------------------------
+
+mapping = ot.da.LinearTransport()
+
+mapping.fit(Xs=X1, Xt=X2)
+
+
+xst = mapping.transform(Xs=X1)
+xts = mapping.inverse_transform(Xt=X2)
+
+I1t = minmax(mat2im(xst, I1.shape))
+I2t = minmax(mat2im(xts, I2.shape))
+
+# %%
+
+
+##############################################################################
+# Plot transformed images
+# -----------------------
+
+pl.figure(2, figsize=(10, 7))
+
+pl.subplot(2, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Im. 1')
+
+pl.subplot(2, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Im. 2')
+
+pl.subplot(2, 2, 3)
+pl.imshow(I1t)
+pl.axis('off')
+pl.title('Mapping Im. 1')
+
+pl.subplot(2, 2, 4)
+pl.imshow(I2t)
+pl.axis('off')
+pl.title('Inverse mapping Im. 2')
diff --git a/docs/source/auto_examples/plot_otda_linear_mapping.rst b/docs/source/auto_examples/plot_otda_linear_mapping.rst
new file mode 100644
index 0000000..8e2e0cf
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_linear_mapping.rst
@@ -0,0 +1,260 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_linear_mapping.py:
+
+
+============================
+Linear OT mapping estimation
+============================
+
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import pylab as pl
+ import ot
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n = 1000
+ d = 2
+ sigma = .1
+
+ # source samples
+ angles = np.random.rand(n, 1) * 2 * np.pi
+ xs = np.concatenate((np.sin(angles), np.cos(angles)),
+ axis=1) + sigma * np.random.randn(n, 2)
+ xs[:n // 2, 1] += 2
+
+
+ # target samples
+ anglet = np.random.rand(n, 1) * 2 * np.pi
+ xt = np.concatenate((np.sin(anglet), np.cos(anglet)),
+ axis=1) + sigma * np.random.randn(n, 2)
+ xt[:n // 2, 1] += 2
+
+
+ A = np.array([[1.5, .7], [.7, 1.5]])
+ b = np.array([[4, 2]])
+ xt = xt.dot(A) + b
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, (5, 5))
+ pl.plot(xs[:, 0], xs[:, 1], '+')
+ pl.plot(xt[:, 0], xt[:, 1], 'o')
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_linear_mapping_001.png
+ :align: center
+
+
+
+
+Estimate linear mapping and transport
+-------------------------------------
+
+
+
+.. code-block:: python
+
+
+ Ae, be = ot.da.OT_mapping_linear(xs, xt)
+
+ xst = xs.dot(Ae) + be
+
+
+
+
+
+
+
+
+Plot transported samples
+------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, (5, 5))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+')
+ pl.plot(xt[:, 0], xt[:, 1], 'o')
+ pl.plot(xst[:, 0], xst[:, 1], '+')
+
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_linear_mapping_002.png
+ :align: center
+
+
+
+
+Load image data
+---------------
+
+
+
+.. code-block:: python
+
+
+
+ def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+ def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+ # Loading images
+ I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+ I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+ X1 = im2mat(I1)
+ X2 = im2mat(I2)
+
+
+
+
+
+
+
+Estimate mapping and adapt
+----------------------------
+
+
+
+.. code-block:: python
+
+
+ mapping = ot.da.LinearTransport()
+
+ mapping.fit(Xs=X1, Xt=X2)
+
+
+ xst = mapping.transform(Xs=X1)
+ xts = mapping.inverse_transform(Xt=X2)
+
+ I1t = minmax(mat2im(xst, I1.shape))
+ I2t = minmax(mat2im(xts, I2.shape))
+
+ # %%
+
+
+
+
+
+
+
+
+Plot transformed images
+-----------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(10, 7))
+
+ pl.subplot(2, 2, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Im. 1')
+
+ pl.subplot(2, 2, 2)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Im. 2')
+
+ pl.subplot(2, 2, 3)
+ pl.imshow(I1t)
+ pl.axis('off')
+ pl.title('Mapping Im. 1')
+
+ pl.subplot(2, 2, 4)
+ pl.imshow(I2t)
+ pl.axis('off')
+ pl.title('Inverse mapping Im. 2')
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_linear_mapping_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.635 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_linear_mapping.py <plot_otda_linear_mapping.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_linear_mapping.ipynb <plot_otda_linear_mapping.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_mapping.ipynb b/docs/source/auto_examples/plot_otda_mapping.ipynb
new file mode 100644
index 0000000..898466d
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping.ipynb
@@ -0,0 +1,126 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OT mapping estimation for domain adaptation\n\n\nThis example presents how to use MappingTransport to estimate at the same\ntime both the coupling transport and approximate the transport map with either\na linear or a kernelized mapping as introduced in [8].\n\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,\n \"Mapping estimation for discrete optimal transport\",\n Neural Information Processing Systems (NIPS), 2016.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_source_samples = 100\nn_target_samples = 100\ntheta = 2 * np.pi / 20\nnoise_level = 0.1\n\nXs, ys = ot.datasets.make_data_classif(\n 'gaussrot', n_source_samples, nz=noise_level)\nXs_new, _ = ot.datasets.make_data_classif(\n 'gaussrot', n_source_samples, nz=noise_level)\nXt, yt = ot.datasets.make_data_classif(\n 'gaussrot', n_target_samples, theta=theta, nz=noise_level)\n\n# one of the target mode changes its variance (no linear mapping)\nXt[yt == 2] *= 3\nXt = Xt + 4"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot data\n---------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, (10, 5))\npl.clf()\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.legend(loc=0)\npl.title('Source and target distributions')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# MappingTransport with linear kernel\not_mapping_linear = ot.da.MappingTransport(\n kernel=\"linear\", mu=1e0, eta=1e-8, bias=True,\n max_iter=20, verbose=True)\n\not_mapping_linear.fit(Xs=Xs, Xt=Xt)\n\n# for original source samples, transform applies barycentric mapping\ntransp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)\n\n# for out of source samples, transform applies the linear mapping\ntransp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)\n\n\n# MappingTransport with gaussian kernel\not_mapping_gaussian = ot.da.MappingTransport(\n kernel=\"gaussian\", eta=1e-5, mu=1e-1, bias=True, sigma=1,\n max_iter=10, verbose=True)\not_mapping_gaussian.fit(Xs=Xs, Xt=Xt)\n\n# for original source samples, transform applies barycentric mapping\ntransp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)\n\n# for out of source samples, transform applies the gaussian mapping\ntransp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot transported samples\n------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2)\npl.clf()\npl.subplot(2, 2, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',\n label='Mapped source samples')\npl.title(\"Bary. mapping (linear)\")\npl.legend(loc=0)\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],\n c=ys, marker='+', label='Learned mapping')\npl.title(\"Estim. mapping (linear)\")\n\npl.subplot(2, 2, 3)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,\n marker='+', label='barycentric mapping')\npl.title(\"Bary. mapping (kernel)\")\n\npl.subplot(2, 2, 4)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,\n marker='+', label='Learned mapping')\npl.title(\"Estim. mapping (kernel)\")\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_mapping.py b/docs/source/auto_examples/plot_otda_mapping.py
new file mode 100644
index 0000000..5880adf
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================
+OT mapping estimation for domain adaptation
+===========================================
+
+This example presents how to use MappingTransport to estimate at the same
+time both the coupling transport and approximate the transport map with either
+a linear or a kernelized mapping as introduced in [8].
+
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xs_new, _ = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.make_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target mode changes its variance (no linear mapping)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# MappingTransport with linear kernel
+ot_mapping_linear = ot.da.MappingTransport(
+ kernel="linear", mu=1e0, eta=1e-8, bias=True,
+ max_iter=20, verbose=True)
+
+ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+# for original source samples, transform applies barycentric mapping
+transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)
+
+# for out of source samples, transform applies the linear mapping
+transp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)
+
+
+# MappingTransport with gaussian kernel
+ot_mapping_gaussian = ot.da.MappingTransport(
+ kernel="gaussian", eta=1e-5, mu=1e-1, bias=True, sigma=1,
+ max_iter=10, verbose=True)
+ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+# for original source samples, transform applies barycentric mapping
+transp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)
+
+# for out of source samples, transform applies the gaussian mapping
+transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)
+
+
+##############################################################################
+# Plot transported samples
+# ------------------------
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',
+ label='Mapped source samples')
+pl.title("Bary. mapping (linear)")
+pl.legend(loc=0)
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],
+ c=ys, marker='+', label='Learned mapping')
+pl.title("Estim. mapping (linear)")
+
+pl.subplot(2, 2, 3)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,
+ marker='+', label='barycentric mapping')
+pl.title("Bary. mapping (kernel)")
+
+pl.subplot(2, 2, 4)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,
+ marker='+', label='Learned mapping')
+pl.title("Estim. mapping (kernel)")
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_mapping.rst b/docs/source/auto_examples/plot_otda_mapping.rst
new file mode 100644
index 0000000..1d95fc6
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping.rst
@@ -0,0 +1,235 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_mapping.py:
+
+
+===========================================
+OT mapping estimation for domain adaptation
+===========================================
+
+This example presents how to use MappingTransport to estimate at the same
+time both the coupling transport and approximate the transport map with either
+a linear or a kernelized mapping as introduced in [8].
+
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_source_samples = 100
+ n_target_samples = 100
+ theta = 2 * np.pi / 20
+ noise_level = 0.1
+
+ Xs, ys = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+ Xs_new, _ = ot.datasets.make_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+ Xt, yt = ot.datasets.make_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+ # one of the target mode changes its variance (no linear mapping)
+ Xt[yt == 2] *= 3
+ Xt = Xt + 4
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, (10, 5))
+ pl.clf()
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('Source and target distributions')
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_001.png
+ :align: center
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # MappingTransport with linear kernel
+ ot_mapping_linear = ot.da.MappingTransport(
+ kernel="linear", mu=1e0, eta=1e-8, bias=True,
+ max_iter=20, verbose=True)
+
+ ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+ # for original source samples, transform applies barycentric mapping
+ transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)
+
+ # for out of source samples, transform applies the linear mapping
+ transp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)
+
+
+ # MappingTransport with gaussian kernel
+ ot_mapping_gaussian = ot.da.MappingTransport(
+ kernel="gaussian", eta=1e-5, mu=1e-1, bias=True, sigma=1,
+ max_iter=10, verbose=True)
+ ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+ # for original source samples, transform applies barycentric mapping
+ transp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)
+
+ # for out of source samples, transform applies the gaussian mapping
+ transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|4.299275e+03|0.000000e+00
+ 1|4.290443e+03|-2.054271e-03
+ 2|4.290040e+03|-9.389994e-05
+ 3|4.289876e+03|-3.830707e-05
+ 4|4.289783e+03|-2.157428e-05
+ 5|4.289724e+03|-1.390941e-05
+ 6|4.289706e+03|-4.051054e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|4.326465e+02|0.000000e+00
+ 1|4.282533e+02|-1.015416e-02
+ 2|4.279473e+02|-7.145955e-04
+ 3|4.277941e+02|-3.580104e-04
+ 4|4.277069e+02|-2.039229e-04
+ 5|4.276462e+02|-1.418698e-04
+ 6|4.276011e+02|-1.054172e-04
+ 7|4.275663e+02|-8.145802e-05
+ 8|4.275405e+02|-6.028774e-05
+ 9|4.275191e+02|-5.005886e-05
+ 10|4.275019e+02|-4.021935e-05
+
+
+Plot transported samples
+------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 2, 1)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',
+ label='Mapped source samples')
+ pl.title("Bary. mapping (linear)")
+ pl.legend(loc=0)
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],
+ c=ys, marker='+', label='Learned mapping')
+ pl.title("Estim. mapping (linear)")
+
+ pl.subplot(2, 2, 3)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,
+ marker='+', label='barycentric mapping')
+ pl.title("Bary. mapping (kernel)")
+
+ pl.subplot(2, 2, 4)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,
+ marker='+', label='Learned mapping')
+ pl.title("Estim. mapping (kernel)")
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_003.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.795 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_mapping.py <plot_otda_mapping.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_mapping.ipynb <plot_otda_mapping.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb b/docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb
new file mode 100644
index 0000000..baffef4
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OT for image color adaptation with mapping estimation\n\n\nOT for domain adaptation with image color adaptation [6] with mapping\nestimation [8].\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized\n discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),\n 1853-1882.\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, \"Mapping estimation for\n discrete optimal transport\", Neural Information Processing Systems (NIPS),\n 2016.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nfrom scipy import ndimage\nimport matplotlib.pylab as pl\nimport ot\n\nr = np.random.RandomState(42)\n\n\ndef im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef mat2im(X, shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\n\ndef minmax(I):\n return np.clip(I, 0, 1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Loading images\nI1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256\nI2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256\n\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)\n\n# training samples\nnb = 1000\nidx1 = r.randint(X1.shape[0], size=(nb,))\nidx2 = r.randint(X2.shape[0], size=(nb,))\n\nXs = X1[idx1, :]\nXt = X2[idx2, :]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Domain adaptation for pixel distribution transfer\n-------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# EMDTransport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\ntransp_Xs_emd = ot_emd.transform(Xs=X1)\nImage_emd = minmax(mat2im(transp_Xs_emd, I1.shape))\n\n# SinkhornTransport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)\nImage_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))\n\not_mapping_linear = ot.da.MappingTransport(\n mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)\not_mapping_linear.fit(Xs=Xs, Xt=Xt)\n\nX1tl = ot_mapping_linear.transform(Xs=X1)\nImage_mapping_linear = minmax(mat2im(X1tl, I1.shape))\n\not_mapping_gaussian = ot.da.MappingTransport(\n mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)\not_mapping_gaussian.fit(Xs=Xs, Xt=Xt)\n\nX1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping\nImage_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot original images\n--------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, figsize=(6.4, 3))\npl.subplot(1, 2, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot pixel values distribution\n------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(6.4, 5))\n\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot transformed images\n-----------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(10, 5))\n\npl.subplot(2, 3, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Im. 1')\n\npl.subplot(2, 3, 4)\npl.imshow(I2)\npl.axis('off')\npl.title('Im. 2')\n\npl.subplot(2, 3, 2)\npl.imshow(Image_emd)\npl.axis('off')\npl.title('EmdTransport')\n\npl.subplot(2, 3, 5)\npl.imshow(Image_sinkhorn)\npl.axis('off')\npl.title('SinkhornTransport')\n\npl.subplot(2, 3, 3)\npl.imshow(Image_mapping_linear)\npl.axis('off')\npl.title('MappingTransport (linear)')\n\npl.subplot(2, 3, 6)\npl.imshow(Image_mapping_gaussian)\npl.axis('off')\npl.title('MappingTransport (gaussian)')\npl.tight_layout()\n\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_mapping_colors_images.py b/docs/source/auto_examples/plot_otda_mapping_colors_images.py
new file mode 100644
index 0000000..a20eca8
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping_colors_images.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================================
+OT for image color adaptation with mapping estimation
+=====================================================
+
+OT for domain adaptation with image color adaptation [6] with mapping
+estimation [8].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),
+ 1853-1882.
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
+ discrete optimal transport", Neural Information Processing Systems (NIPS),
+ 2016.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+from scipy import ndimage
+import matplotlib.pylab as pl
+import ot
+
+r = np.random.RandomState(42)
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# Loading images
+I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+# training samples
+nb = 1000
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
+
+Xs = X1[idx1, :]
+Xt = X2[idx2, :]
+
+
+##############################################################################
+# Domain adaptation for pixel distribution transfer
+# -------------------------------------------------
+
+# EMDTransport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+transp_Xs_emd = ot_emd.transform(Xs=X1)
+Image_emd = minmax(mat2im(transp_Xs_emd, I1.shape))
+
+# SinkhornTransport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+
+ot_mapping_linear = ot.da.MappingTransport(
+ mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)
+ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+X1tl = ot_mapping_linear.transform(Xs=X1)
+Image_mapping_linear = minmax(mat2im(X1tl, I1.shape))
+
+ot_mapping_gaussian = ot.da.MappingTransport(
+ mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)
+ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping
+Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
+
+
+##############################################################################
+# Plot original images
+# --------------------
+
+pl.figure(1, figsize=(6.4, 3))
+pl.subplot(1, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Plot pixel values distribution
+# ------------------------------
+
+pl.figure(2, figsize=(6.4, 5))
+
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Plot transformed images
+# -----------------------
+
+pl.figure(2, figsize=(10, 5))
+
+pl.subplot(2, 3, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Im. 1')
+
+pl.subplot(2, 3, 4)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Im. 2')
+
+pl.subplot(2, 3, 2)
+pl.imshow(Image_emd)
+pl.axis('off')
+pl.title('EmdTransport')
+
+pl.subplot(2, 3, 5)
+pl.imshow(Image_sinkhorn)
+pl.axis('off')
+pl.title('SinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(Image_mapping_linear)
+pl.axis('off')
+pl.title('MappingTransport (linear)')
+
+pl.subplot(2, 3, 6)
+pl.imshow(Image_mapping_gaussian)
+pl.axis('off')
+pl.title('MappingTransport (gaussian)')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_mapping_colors_images.rst b/docs/source/auto_examples/plot_otda_mapping_colors_images.rst
new file mode 100644
index 0000000..2afdc8a
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping_colors_images.rst
@@ -0,0 +1,310 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_mapping_colors_images.py:
+
+
+=====================================================
+OT for image color adaptation with mapping estimation
+=====================================================
+
+OT for domain adaptation with image color adaptation [6] with mapping
+estimation [8].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),
+ 1853-1882.
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
+ discrete optimal transport", Neural Information Processing Systems (NIPS),
+ 2016.
+
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import numpy as np
+ from scipy import ndimage
+ import matplotlib.pylab as pl
+ import ot
+
+ r = np.random.RandomState(42)
+
+
+ def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ def mat2im(X, shape):
+ """Converts back a matrix to an image"""
+ return X.reshape(shape)
+
+
+ def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ # Loading images
+ I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+ I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+ X1 = im2mat(I1)
+ X2 = im2mat(I2)
+
+ # training samples
+ nb = 1000
+ idx1 = r.randint(X1.shape[0], size=(nb,))
+ idx2 = r.randint(X2.shape[0], size=(nb,))
+
+ Xs = X1[idx1, :]
+ Xt = X2[idx2, :]
+
+
+
+
+
+
+
+
+Domain adaptation for pixel distribution transfer
+-------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMDTransport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+ transp_Xs_emd = ot_emd.transform(Xs=X1)
+ Image_emd = minmax(mat2im(transp_Xs_emd, I1.shape))
+
+ # SinkhornTransport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+ Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+
+ ot_mapping_linear = ot.da.MappingTransport(
+ mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)
+ ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+ X1tl = ot_mapping_linear.transform(Xs=X1)
+ Image_mapping_linear = minmax(mat2im(X1tl, I1.shape))
+
+ ot_mapping_gaussian = ot.da.MappingTransport(
+ mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)
+ ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+ X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping
+ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|3.680534e+02|0.000000e+00
+ 1|3.592501e+02|-2.391854e-02
+ 2|3.590682e+02|-5.061555e-04
+ 3|3.589745e+02|-2.610227e-04
+ 4|3.589167e+02|-1.611644e-04
+ 5|3.588768e+02|-1.109242e-04
+ 6|3.588482e+02|-7.972733e-05
+ 7|3.588261e+02|-6.166174e-05
+ 8|3.588086e+02|-4.871697e-05
+ 9|3.587946e+02|-3.919056e-05
+ 10|3.587830e+02|-3.228124e-05
+ 11|3.587731e+02|-2.744744e-05
+ 12|3.587648e+02|-2.334451e-05
+ 13|3.587576e+02|-1.995629e-05
+ 14|3.587513e+02|-1.761058e-05
+ 15|3.587457e+02|-1.542568e-05
+ 16|3.587408e+02|-1.366315e-05
+ 17|3.587365e+02|-1.221732e-05
+ 18|3.587325e+02|-1.102488e-05
+ 19|3.587303e+02|-6.062107e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|3.784871e+02|0.000000e+00
+ 1|3.646491e+02|-3.656142e-02
+ 2|3.642975e+02|-9.642655e-04
+ 3|3.641626e+02|-3.702413e-04
+ 4|3.640888e+02|-2.026301e-04
+ 5|3.640419e+02|-1.289607e-04
+ 6|3.640097e+02|-8.831646e-05
+ 7|3.639861e+02|-6.487612e-05
+ 8|3.639679e+02|-4.994063e-05
+ 9|3.639536e+02|-3.941436e-05
+ 10|3.639419e+02|-3.209753e-05
+
+
+Plot original images
+--------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.subplot(1, 2, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Image 2')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png
+ :align: center
+
+
+
+
+Plot pixel values distribution
+------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(6.4, 5))
+
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 2')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png
+ :align: center
+
+
+
+
+Plot transformed images
+-----------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(10, 5))
+
+ pl.subplot(2, 3, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Im. 1')
+
+ pl.subplot(2, 3, 4)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Im. 2')
+
+ pl.subplot(2, 3, 2)
+ pl.imshow(Image_emd)
+ pl.axis('off')
+ pl.title('EmdTransport')
+
+ pl.subplot(2, 3, 5)
+ pl.imshow(Image_sinkhorn)
+ pl.axis('off')
+ pl.title('SinkhornTransport')
+
+ pl.subplot(2, 3, 3)
+ pl.imshow(Image_mapping_linear)
+ pl.axis('off')
+ pl.title('MappingTransport (linear)')
+
+ pl.subplot(2, 3, 6)
+ pl.imshow(Image_mapping_gaussian)
+ pl.axis('off')
+ pl.title('MappingTransport (gaussian)')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 3 minutes 14.206 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_mapping_colors_images.py <plot_otda_mapping_colors_images.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_mapping_colors_images.ipynb <plot_otda_mapping_colors_images.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_semi_supervised.ipynb b/docs/source/auto_examples/plot_otda_semi_supervised.ipynb
new file mode 100644
index 0000000..e3192da
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_semi_supervised.ipynb
@@ -0,0 +1,144 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OTDA unsupervised vs semi-supervised setting\n\n\nThis example introduces a semi supervised domain adaptation in a 2D setting.\nIt explicits the problem of semi supervised domain adaptation and introduces\nsome optimal transport approaches to solve it.\n\nQuantities such as optimal couplings, greater coupling coefficients and\ntransported samples are represented in order to give a visual understanding\nof what the transport methods are doing.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport ot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_samples_source = 150\nn_samples_target = 150\n\nXs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)\nXt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Transport source samples onto target samples\n--------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# unsupervised domain adaptation\not_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn_un.fit(Xs=Xs, Xt=Xt)\ntransp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)\n\n# semi-supervised domain adaptation\not_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)\ntransp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)\n\n# semi supervised DA uses available labaled target samples to modify the cost\n# matrix involved in the OT problem. The cost of transporting a source sample\n# of class A onto a target sample of class B != A is set to infinite, or a\n# very large value\n\n# note that in the present case we consider that all the target samples are\n# labeled. For daily applications, some target sample might not have labels,\n# in this case the element of yt corresponding to these samples should be\n# filled with -1.\n\n# Warning: we recall that -1 cannot be used as a class label"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 1 : plots source and target samples + matrix of pairwise distance\n---------------------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1, figsize=(10, 10))\npl.subplot(2, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Source samples')\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Target samples')\n\npl.subplot(2, 2, 3)\npl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Cost matrix - unsupervised DA')\n\npl.subplot(2, 2, 4)\npl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Cost matrix - semisupervised DA')\n\npl.tight_layout()\n\n# the optimal coupling in the semi-supervised DA case will exhibit \" shape\n# similar\" to the cost matrix, (block diagonal matrix)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 2 : plots optimal couplings for the different methods\n---------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(8, 4))\n\npl.subplot(1, 2, 1)\npl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nUnsupervised DA')\n\npl.subplot(1, 2, 2)\npl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSemi-supervised DA')\n\npl.tight_layout()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 3 : plot transported samples\n--------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# display transported samples\npl.figure(4, figsize=(8, 4))\npl.subplot(1, 2, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nEmdTransport')\npl.legend(loc=0)\npl.xticks([])\npl.yticks([])\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nSinkhornTransport')\npl.xticks([])\npl.yticks([])\n\npl.tight_layout()\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_semi_supervised.py b/docs/source/auto_examples/plot_otda_semi_supervised.py
new file mode 100644
index 0000000..8a67720
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_semi_supervised.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+"""
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces a semi supervised domain adaptation in a 2D setting.
+It explicits the problem of semi supervised domain adaptation and introduces
+some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, greater coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)
+
+
+##############################################################################
+# Transport source samples onto target samples
+# --------------------------------------------
+
+
+# unsupervised domain adaptation
+ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+# semi-supervised domain adaptation
+ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+# semi supervised DA uses available labaled target samples to modify the cost
+# matrix involved in the OT problem. The cost of transporting a source sample
+# of class A onto a target sample of class B != A is set to infinite, or a
+# very large value
+
+# note that in the present case we consider that all the target samples are
+# labeled. For daily applications, some target sample might not have labels,
+# in this case the element of yt corresponding to these samples should be
+# filled with -1.
+
+# Warning: we recall that -1 cannot be used as a class label
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+# ---------------------------------------------------------------------
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - unsupervised DA')
+
+pl.subplot(2, 2, 4)
+pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - semisupervised DA')
+
+pl.tight_layout()
+
+# the optimal coupling in the semi-supervised DA case will exhibit " shape
+# similar" to the cost matrix, (block diagonal matrix)
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+# ---------------------------------------------------------
+
+pl.figure(2, figsize=(8, 4))
+
+pl.subplot(1, 2, 1)
+pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nUnsupervised DA')
+
+pl.subplot(1, 2, 2)
+pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSemi-supervised DA')
+
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+# --------------------------------
+
+# display transported samples
+pl.figure(4, figsize=(8, 4))
+pl.subplot(1, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornTransport')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_semi_supervised.rst b/docs/source/auto_examples/plot_otda_semi_supervised.rst
new file mode 100644
index 0000000..2ed7819
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_semi_supervised.rst
@@ -0,0 +1,245 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_semi_supervised.py:
+
+
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces a semi supervised domain adaptation in a 2D setting.
+It explicits the problem of semi supervised domain adaptation and introduces
+some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, greater coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_samples_source = 150
+ n_samples_target = 150
+
+ Xs, ys = ot.datasets.make_data_classif('3gauss', n_samples_source)
+ Xt, yt = ot.datasets.make_data_classif('3gauss2', n_samples_target)
+
+
+
+
+
+
+
+
+Transport source samples onto target samples
+--------------------------------------------
+
+
+
+.. code-block:: python
+
+
+
+ # unsupervised domain adaptation
+ ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+ transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+ # semi-supervised domain adaptation
+ ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+ transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+ # semi supervised DA uses available labaled target samples to modify the cost
+ # matrix involved in the OT problem. The cost of transporting a source sample
+ # of class A onto a target sample of class B != A is set to infinite, or a
+ # very large value
+
+ # note that in the present case we consider that all the target samples are
+ # labeled. For daily applications, some target sample might not have labels,
+ # in this case the element of yt corresponding to these samples should be
+ # filled with -1.
+
+ # Warning: we recall that -1 cannot be used as a class label
+
+
+
+
+
+
+
+
+Fig 1 : plots source and target samples + matrix of pairwise distance
+---------------------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(10, 10))
+ pl.subplot(2, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Source samples')
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Target samples')
+
+ pl.subplot(2, 2, 3)
+ pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Cost matrix - unsupervised DA')
+
+ pl.subplot(2, 2, 4)
+ pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Cost matrix - semisupervised DA')
+
+ pl.tight_layout()
+
+ # the optimal coupling in the semi-supervised DA case will exhibit " shape
+ # similar" to the cost matrix, (block diagonal matrix)
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png
+ :align: center
+
+
+
+
+Fig 2 : plots optimal couplings for the different methods
+---------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(8, 4))
+
+ pl.subplot(1, 2, 1)
+ pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nUnsupervised DA')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSemi-supervised DA')
+
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png
+ :align: center
+
+
+
+
+Fig 3 : plot transported samples
+--------------------------------
+
+
+
+.. code-block:: python
+
+
+ # display transported samples
+ pl.figure(4, figsize=(8, 4))
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nEmdTransport')
+ pl.legend(loc=0)
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nSinkhornTransport')
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.tight_layout()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.256 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_semi_supervised.py <plot_otda_semi_supervised.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_semi_supervised.ipynb <plot_otda_semi_supervised.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_stochastic.ipynb b/docs/source/auto_examples/plot_stochastic.ipynb
new file mode 100644
index 0000000..7f6ff3d
--- /dev/null
+++ b/docs/source/auto_examples/plot_stochastic.ipynb
@@ -0,0 +1,295 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# Stochastic examples\n\n\nThis example is designed to show how to use the stochatic optimization\nalgorithms for descrete and semicontinous measures from the POT library.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Author: Kilian Fatras <kilian.fatras@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport numpy as np\nimport ot\nimport ot.plot"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM\n############################################################################\n############################################################################\n DISCRETE CASE:\n\n Sample two discrete measures for the discrete case\n ---------------------------------------------\n\n Define 2 discrete measures a and b, the points where are defined the source\n and the target measures and finally the cost matrix c.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 1000\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Call the \"SAG\" method to find the transportation matrix in the discrete case\n---------------------------------------------\n\nDefine the method \"SAG\", call ot.solve_semi_dual_entropic and plot the\nresults.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "method = \"SAG\"\nsag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,\n numItermax)\nprint(sag_pi)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "SEMICONTINOUS CASE:\n\nSample one general measure a, one discrete measures b for the semicontinous\ncase\n---------------------------------------------\n\nDefine one general measure a, one discrete measures b, the points where\nare defined the source and the target measures and finally the cost matrix c.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 1000\nlog = True\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Call the \"ASGD\" method to find the transportation matrix in the semicontinous\ncase\n---------------------------------------------\n\nDefine the method \"ASGD\", call ot.solve_semi_dual_entropic and plot the\nresults.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "method = \"ASGD\"\nasgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,\n numItermax, log=log)\nprint(log_asgd['alpha'], log_asgd['beta'])\nprint(asgd_pi)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compare the results with the Sinkhorn algorithm\n---------------------------------------------\n\nCall the Sinkhorn algorithm from POT\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "sinkhorn_pi = ot.sinkhorn(a, b, M, reg)\nprint(sinkhorn_pi)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "PLOT TRANSPORTATION MATRIX\n#############################################################################\n\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot SAG results\n----------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot ASGD results\n-----------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot Sinkhorn results\n---------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM\n############################################################################\n############################################################################\n SEMICONTINOUS CASE:\n\n Sample one general measure a, one discrete measures b for the semicontinous\n case\n ---------------------------------------------\n\n Define one general measure a, one discrete measures b, the points where\n are defined the source and the target measures and finally the cost matrix c.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 100000\nlr = 0.1\nbatch_size = 3\nlog = True\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Call the \"SGD\" dual method to find the transportation matrix in the\nsemicontinous case\n---------------------------------------------\n\nCall ot.solve_dual_entropic and plot the results.\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,\n batch_size, numItermax,\n lr, log=log)\nprint(log_sgd['alpha'], log_sgd['beta'])\nprint(sgd_dual_pi)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Compare the results with the Sinkhorn algorithm\n---------------------------------------------\n\nCall the Sinkhorn algorithm from POT\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "sinkhorn_pi = ot.sinkhorn(a, b, M, reg)\nprint(sinkhorn_pi)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot SGD results\n-----------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Plot Sinkhorn results\n---------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_stochastic.py b/docs/source/auto_examples/plot_stochastic.py
new file mode 100644
index 0000000..742f8d9
--- /dev/null
+++ b/docs/source/auto_examples/plot_stochastic.py
@@ -0,0 +1,208 @@
+"""
+==========================
+Stochastic examples
+==========================
+
+This example is designed to show how to use the stochatic optimization
+algorithms for descrete and semicontinous measures from the POT library.
+
+"""
+
+# Author: Kilian Fatras <kilian.fatras@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import numpy as np
+import ot
+import ot.plot
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
+#############################################################################
+#############################################################################
+# DISCRETE CASE:
+#
+# Sample two discrete measures for the discrete case
+# ---------------------------------------------
+#
+# Define 2 discrete measures a and b, the points where are defined the source
+# and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SAG" method to find the transportation matrix in the discrete case
+# ---------------------------------------------
+#
+# Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
+# results.
+
+method = "SAG"
+sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax)
+print(sag_pi)
+
+#############################################################################
+# SEMICONTINOUS CASE:
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 1000
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "ASGD" method to find the transportation matrix in the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
+# results.
+
+method = "ASGD"
+asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, log=log)
+print(log_asgd['alpha'], log_asgd['beta'])
+print(asgd_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+
+##############################################################################
+# PLOT TRANSPORTATION MATRIX
+##############################################################################
+
+##############################################################################
+# Plot SAG results
+# ----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
+pl.show()
+
+
+##############################################################################
+# Plot ASGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
+
+
+#############################################################################
+# COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
+#############################################################################
+#############################################################################
+# SEMICONTINOUS CASE:
+#
+# Sample one general measure a, one discrete measures b for the semicontinous
+# case
+# ---------------------------------------------
+#
+# Define one general measure a, one discrete measures b, the points where
+# are defined the source and the target measures and finally the cost matrix c.
+
+n_source = 7
+n_target = 4
+reg = 1
+numItermax = 100000
+lr = 0.1
+batch_size = 3
+log = True
+
+a = ot.utils.unif(n_source)
+b = ot.utils.unif(n_target)
+
+rng = np.random.RandomState(0)
+X_source = rng.randn(n_source, 2)
+Y_target = rng.randn(n_target, 2)
+M = ot.dist(X_source, Y_target)
+
+#############################################################################
+#
+# Call the "SGD" dual method to find the transportation matrix in the
+# semicontinous case
+# ---------------------------------------------
+#
+# Call ot.solve_dual_entropic and plot the results.
+
+sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
+ batch_size, numItermax,
+ lr, log=log)
+print(log_sgd['alpha'], log_sgd['beta'])
+print(sgd_dual_pi)
+
+#############################################################################
+#
+# Compare the results with the Sinkhorn algorithm
+# ---------------------------------------------
+#
+# Call the Sinkhorn algorithm from POT
+
+sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+print(sinkhorn_pi)
+
+##############################################################################
+# Plot SGD results
+# -----------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
+pl.show()
+
+
+##############################################################################
+# Plot Sinkhorn results
+# ---------------------
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+pl.show()
diff --git a/docs/source/auto_examples/plot_stochastic.rst b/docs/source/auto_examples/plot_stochastic.rst
new file mode 100644
index 0000000..d531045
--- /dev/null
+++ b/docs/source/auto_examples/plot_stochastic.rst
@@ -0,0 +1,446 @@
+
+
+.. _sphx_glr_auto_examples_plot_stochastic.py:
+
+
+==========================
+Stochastic examples
+==========================
+
+This example is designed to show how to use the stochatic optimization
+algorithms for descrete and semicontinous measures from the POT library.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Kilian Fatras <kilian.fatras@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import numpy as np
+ import ot
+ import ot.plot
+
+
+
+
+
+
+
+
+COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM
+############################################################################
+############################################################################
+ DISCRETE CASE:
+
+ Sample two discrete measures for the discrete case
+ ---------------------------------------------
+
+ Define 2 discrete measures a and b, the points where are defined the source
+ and the target measures and finally the cost matrix c.
+
+
+
+.. code-block:: python
+
+
+ n_source = 7
+ n_target = 4
+ reg = 1
+ numItermax = 1000
+
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ rng = np.random.RandomState(0)
+ X_source = rng.randn(n_source, 2)
+ Y_target = rng.randn(n_target, 2)
+ M = ot.dist(X_source, Y_target)
+
+
+
+
+
+
+
+Call the "SAG" method to find the transportation matrix in the discrete case
+---------------------------------------------
+
+Define the method "SAG", call ot.solve_semi_dual_entropic and plot the
+results.
+
+
+
+.. code-block:: python
+
+
+ method = "SAG"
+ sag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax)
+ print(sag_pi)
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ [[2.55553509e-02 9.96395660e-02 1.76579142e-02 4.31178196e-06]
+ [1.21640234e-01 1.25357448e-02 1.30225078e-03 7.37891338e-03]
+ [3.56123975e-03 7.61451746e-02 6.31505947e-02 1.33831456e-07]
+ [2.61515202e-02 3.34246014e-02 8.28734709e-02 4.07550428e-04]
+ [9.85500870e-03 7.52288517e-04 1.08262628e-02 1.21423583e-01]
+ [2.16904253e-02 9.03825797e-04 1.87178503e-03 1.18391107e-01]
+ [4.15462212e-02 2.65987989e-02 7.23177216e-02 2.39440107e-03]]
+
+
+SEMICONTINOUS CASE:
+
+Sample one general measure a, one discrete measures b for the semicontinous
+case
+---------------------------------------------
+
+Define one general measure a, one discrete measures b, the points where
+are defined the source and the target measures and finally the cost matrix c.
+
+
+
+.. code-block:: python
+
+
+ n_source = 7
+ n_target = 4
+ reg = 1
+ numItermax = 1000
+ log = True
+
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ rng = np.random.RandomState(0)
+ X_source = rng.randn(n_source, 2)
+ Y_target = rng.randn(n_target, 2)
+ M = ot.dist(X_source, Y_target)
+
+
+
+
+
+
+
+Call the "ASGD" method to find the transportation matrix in the semicontinous
+case
+---------------------------------------------
+
+Define the method "ASGD", call ot.solve_semi_dual_entropic and plot the
+results.
+
+
+
+.. code-block:: python
+
+
+ method = "ASGD"
+ asgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,
+ numItermax, log=log)
+ print(log_asgd['alpha'], log_asgd['beta'])
+ print(asgd_pi)
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ [3.98220325 7.76235856 3.97645524 2.72051681 1.23219313 3.07696856
+ 2.84476972] [-2.65544161 -2.50838395 -0.9397765 6.10360206]
+ [[2.34528761e-02 1.00491956e-01 1.89058354e-02 6.47543413e-06]
+ [1.16616747e-01 1.32074516e-02 1.45653361e-03 1.15764107e-02]
+ [3.16154850e-03 7.42892944e-02 6.54061055e-02 1.94426150e-07]
+ [2.33152216e-02 3.27486992e-02 8.61986263e-02 5.94595747e-04]
+ [6.34131496e-03 5.31975896e-04 8.12724003e-03 1.27856612e-01]
+ [1.41744829e-02 6.49096245e-04 1.42704389e-03 1.26606520e-01]
+ [3.73127657e-02 2.62526499e-02 7.57727161e-02 3.51901117e-03]]
+
+
+Compare the results with the Sinkhorn algorithm
+---------------------------------------------
+
+Call the Sinkhorn algorithm from POT
+
+
+
+.. code-block:: python
+
+
+ sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+ print(sinkhorn_pi)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ [[2.55535622e-02 9.96413843e-02 1.76578860e-02 4.31043335e-06]
+ [1.21640742e-01 1.25369034e-02 1.30234529e-03 7.37715259e-03]
+ [3.56096458e-03 7.61460101e-02 6.31500344e-02 1.33788624e-07]
+ [2.61499607e-02 3.34255577e-02 8.28741973e-02 4.07427179e-04]
+ [9.85698720e-03 7.52505948e-04 1.08291770e-02 1.21418473e-01]
+ [2.16947591e-02 9.04086158e-04 1.87228707e-03 1.18386011e-01]
+ [4.15442692e-02 2.65998963e-02 7.23192701e-02 2.39370724e-03]]
+
+
+PLOT TRANSPORTATION MATRIX
+#############################################################################
+
+
+Plot SAG results
+----------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_stochastic_004.png
+ :align: center
+
+
+
+
+Plot ASGD results
+-----------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_stochastic_005.png
+ :align: center
+
+
+
+
+Plot Sinkhorn results
+---------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_stochastic_006.png
+ :align: center
+
+
+
+
+COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM
+############################################################################
+############################################################################
+ SEMICONTINOUS CASE:
+
+ Sample one general measure a, one discrete measures b for the semicontinous
+ case
+ ---------------------------------------------
+
+ Define one general measure a, one discrete measures b, the points where
+ are defined the source and the target measures and finally the cost matrix c.
+
+
+
+.. code-block:: python
+
+
+ n_source = 7
+ n_target = 4
+ reg = 1
+ numItermax = 100000
+ lr = 0.1
+ batch_size = 3
+ log = True
+
+ a = ot.utils.unif(n_source)
+ b = ot.utils.unif(n_target)
+
+ rng = np.random.RandomState(0)
+ X_source = rng.randn(n_source, 2)
+ Y_target = rng.randn(n_target, 2)
+ M = ot.dist(X_source, Y_target)
+
+
+
+
+
+
+
+Call the "SGD" dual method to find the transportation matrix in the
+semicontinous case
+---------------------------------------------
+
+Call ot.solve_dual_entropic and plot the results.
+
+
+
+.. code-block:: python
+
+
+ sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,
+ batch_size, numItermax,
+ lr, log=log)
+ print(log_sgd['alpha'], log_sgd['beta'])
+ print(sgd_dual_pi)
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ [0.92449986 2.75486107 1.07923806 0.02741145 0.61355413 1.81961594
+ 0.12072562] [0.33831611 0.46806842 1.5640451 4.96947652]
+ [[2.20001105e-02 9.26497883e-02 1.08654588e-02 9.78995555e-08]
+ [1.55669974e-02 1.73279561e-03 1.19120878e-04 2.49058251e-05]
+ [3.48198483e-03 8.04151063e-02 4.41335396e-02 3.45115752e-09]
+ [3.14927954e-02 4.34760520e-02 7.13338154e-02 1.29442395e-05]
+ [6.81836550e-02 5.62182457e-03 5.35386584e-02 2.21568095e-02]
+ [8.04671052e-02 3.62163462e-03 4.96331605e-03 1.15837801e-02]
+ [4.88644009e-02 3.37903481e-02 6.07955004e-02 7.42743505e-05]]
+
+
+Compare the results with the Sinkhorn algorithm
+---------------------------------------------
+
+Call the Sinkhorn algorithm from POT
+
+
+
+.. code-block:: python
+
+
+ sinkhorn_pi = ot.sinkhorn(a, b, M, reg)
+ print(sinkhorn_pi)
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ [[2.55535622e-02 9.96413843e-02 1.76578860e-02 4.31043335e-06]
+ [1.21640742e-01 1.25369034e-02 1.30234529e-03 7.37715259e-03]
+ [3.56096458e-03 7.61460101e-02 6.31500344e-02 1.33788624e-07]
+ [2.61499607e-02 3.34255577e-02 8.28741973e-02 4.07427179e-04]
+ [9.85698720e-03 7.52505948e-04 1.08291770e-02 1.21418473e-01]
+ [2.16947591e-02 9.04086158e-04 1.87228707e-03 1.18386011e-01]
+ [4.15442692e-02 2.65998963e-02 7.23192701e-02 2.39370724e-03]]
+
+
+Plot SGD results
+-----------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_stochastic_007.png
+ :align: center
+
+
+
+
+Plot Sinkhorn results
+---------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_stochastic_008.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 20.889 seconds)
+
+
+
+.. only :: html
+
+ .. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_stochastic.py <plot_stochastic.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_stochastic.ipynb <plot_stochastic.ipynb>`
+
+
+.. only:: html
+
+ .. rst-class:: sphx-glr-signature
+
+ `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/searchindex b/docs/source/auto_examples/searchindex
new file mode 100644
index 0000000..2cad500
--- /dev/null
+++ b/docs/source/auto_examples/searchindex
Binary files differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
new file mode 100644
index 0000000..d29b829
--- /dev/null
+++ b/docs/source/conf.py
@@ -0,0 +1,343 @@
+# -*- coding: utf-8 -*-
+#
+# POT documentation build configuration file, created by
+# sphinx-quickstart on Mon Oct 24 11:10:10 2016.
+#
+# This file is execfile()d with the current directory set to its
+# containing dir.
+#
+# Note that not all possible configuration values are present in this
+# autogenerated file.
+#
+# All configuration values have a default; values that are commented out
+# serve to show the default.
+
+import sys
+import os
+import re
+try:
+ import sphinx_gallery
+except ImportError:
+ print("warning sphinx-gallery not installed")
+
+# !!!! allow readthedoc compilation
+try:
+ from unittest.mock import MagicMock
+except ImportError:
+ from mock import Mock as MagicMock
+ ## check whether in the source directory...
+#
+
+
+#!!! This should be commented when executing sphinx-gallery
+class Mock(MagicMock):
+ @classmethod
+ def __getattr__(cls, name):
+ return MagicMock()
+MOCK_MODULES = ['ot.lp.emd_wrap','autograd','pymanopt','cupy','autograd.numpy','pymanopt.manifolds','pymanopt.solvers']
+# 'autograd.numpy','pymanopt.manifolds','pymanopt.solvers',
+sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
+# !!!!
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+sys.path.insert(0, os.path.abspath('.'))
+sys.path.insert(0, os.path.abspath('..'))
+sys.path.insert(0, os.path.abspath("../.."))
+
+
+# -- General configuration ------------------------------------------------
+
+# If your documentation needs a minimal Sphinx version, state it here.
+#needs_sphinx = '1.0'
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named #'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.doctest',
+ 'sphinx.ext.intersphinx',
+ 'sphinx.ext.todo',
+ 'sphinx.ext.coverage',
+ 'sphinx.ext.mathjax',
+ 'sphinx.ext.ifconfig',
+ 'sphinx.ext.viewcode',
+ 'sphinx.ext.napoleon',
+ #'sphinx_gallery.gen_gallery',
+]
+
+napoleon_numpy_docstring = True
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The suffix(es) of source filenames.
+# You can specify multiple suffix as a list of string:
+# source_suffix = ['.rst', '.md']
+source_suffix = '.rst'
+
+# The encoding of source files.
+source_encoding = 'utf-8-sig'
+
+# The master toctree document.
+master_doc = 'index'
+
+# General information about the project.
+project = u'POT Python Optimal Transport'
+copyright = u'2016-2019, Rémi Flamary, Nicolas Courty'
+author = u'Rémi Flamary, Nicolas Courty'
+
+# The version info for the project you're documenting, acts as replacement for
+# |version| and |release|, also used in various other places throughout the
+# built documents.
+#
+
+__version__ = re.search(
+ r'__version__\s*=\s*[\'"]([^\'"]*)[\'"]', # It excludes inline comment too
+ open('../../ot/__init__.py').read()).group(1)
+# The short X.Y version.
+version = __version__
+# The full version, including alpha/beta/rc tags.
+release = __version__
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# There are two options for replacing |today|: either, you set today to some
+# non-false value, then it is used:
+#today = ''
+# Else, today_fmt is used as the format for a strftime call.
+#today_fmt = '%B %d, %Y'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+exclude_patterns = []
+
+# The reST default role (used for this markup: `text`) to use for all
+# documents.
+#default_role = None
+
+# If true, '()' will be appended to :func: etc. cross-reference text.
+#add_function_parentheses = True
+
+# If true, the current module name will be prepended to all description
+# unit titles (such as .. function::).
+#add_module_names = True
+
+# If true, sectionauthor and moduleauthor directives will be shown in the
+# output. They are ignored by default.
+#show_authors = False
+
+# The name of the Pygments (syntax highlighting) style to use.
+pygments_style = 'sphinx'
+
+# A list of ignored prefixes for module index sorting.
+#modindex_common_prefix = []
+
+# If true, keep warnings as "system message" paragraphs in the built documents.
+#keep_warnings = False
+
+# If true, `todo` and `todoList` produce output, else they produce nothing.
+todo_include_todos = True
+
+
+# -- Options for HTML output ----------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+html_theme = 'sphinx_rtd_theme'
+
+# Theme options are theme-specific and customize the look and feel of a theme
+# further. For a list of options available for each theme, see the
+# documentation.
+#html_theme_options = {}
+
+# Add any paths that contain custom themes here, relative to this directory.
+#html_theme_path = []
+
+# The name for this set of Sphinx documents. If None, it defaults to
+# "<project> v<release> documentation".
+#html_title = None
+
+# A shorter title for the navigation bar. Default is the same as html_title.
+#html_short_title = None
+
+# The name of an image file (relative to this directory) to place at the top
+# of the sidebar.
+#html_logo = None
+
+# The name of an image file (relative to this directory) to use as a favicon of
+# the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
+# pixels large.
+#html_favicon = None
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+
+# Add any extra paths that contain custom files (such as robots.txt or
+# .htaccess) here, relative to this directory. These files are copied
+# directly to the root of the documentation.
+#html_extra_path = []
+
+# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
+# using the given strftime format.
+#html_last_updated_fmt = '%b %d, %Y'
+
+# If true, SmartyPants will be used to convert quotes and dashes to
+# typographically correct entities.
+#html_use_smartypants = True
+
+# Custom sidebar templates, maps document names to template names.
+#html_sidebars = {}
+
+# Additional templates that should be rendered to pages, maps page names to
+# template names.
+#html_additional_pages = {}
+
+# If false, no module index is generated.
+#html_domain_indices = True
+
+# If false, no index is generated.
+#html_use_index = True
+
+# If true, the index is split into individual pages for each letter.
+#html_split_index = False
+
+# If true, links to the reST sources are added to the pages.
+#html_show_sourcelink = True
+
+# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
+#html_show_sphinx = True
+
+# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
+#html_show_copyright = True
+
+# If true, an OpenSearch description file will be output, and all pages will
+# contain a <link> tag referring to it. The value of this option must be the
+# base URL from which the finished HTML is served.
+#html_use_opensearch = ''
+
+# This is the file name suffix for HTML files (e.g. ".xhtml").
+#html_file_suffix = None
+
+# Language to be used for generating the HTML full-text search index.
+# Sphinx supports the following languages:
+# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
+# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
+#html_search_language = 'en'
+
+# A dictionary with options for the search language support, empty by default.
+# Now only 'ja' uses this config value
+#html_search_options = {'type': 'default'}
+
+# The name of a javascript file (relative to the configuration directory) that
+# implements a search results scorer. If empty, the default will be used.
+#html_search_scorer = 'scorer.js'
+
+# Output file base name for HTML help builder.
+htmlhelp_basename = 'POTdoc'
+
+# -- Options for LaTeX output ---------------------------------------------
+
+latex_elements = {
+# The paper size ('letterpaper' or 'a4paper').
+#'papersize': 'letterpaper',
+
+# The font size ('10pt', '11pt' or '12pt').
+#'pointsize': '10pt',
+
+# Additional stuff for the LaTeX preamble.
+#'preamble': '',
+
+# Latex figure (float) alignment
+#'figure_align': 'htbp',
+}
+
+# Grouping the document tree into LaTeX files. List of tuples
+# (source start file, target name, title,
+# author, documentclass [howto, manual, or own class]).
+latex_documents = [
+ (master_doc, 'POT.tex', u'POT Python Optimal Transport library',
+ author, 'manual'),
+]
+
+# The name of an image file (relative to this directory) to place at the top of
+# the title page.
+#latex_logo = None
+
+# For "manual" documents, if this is true, then toplevel headings are parts,
+# not chapters.
+#latex_use_parts = False
+
+# If true, show page references after internal links.
+#latex_show_pagerefs = False
+
+# If true, show URL addresses after external links.
+#latex_show_urls = False
+
+# Documents to append as an appendix to all manuals.
+#latex_appendices = []
+
+# If false, no module index is generated.
+#latex_domain_indices = True
+
+
+# -- Options for manual page output ---------------------------------------
+
+# One entry per manual page. List of tuples
+# (source start file, name, description, authors, manual section).
+man_pages = [
+ (master_doc, 'pot', u'POT Python Optimal Transport library Documentation',
+ [author], 1)
+]
+
+# If true, show URL addresses after external links.
+#man_show_urls = False
+
+
+# -- Options for Texinfo output -------------------------------------------
+
+# Grouping the document tree into Texinfo files. List of tuples
+# (source start file, target name, title, author,
+# dir menu entry, description, category)
+texinfo_documents = [
+ (master_doc, 'POT', u'POT Python Optimal Transport library Documentation',
+ author, 'POT', 'Python Optimal Transport librar.',
+ 'Miscellaneous'),
+]
+
+# Documents to append as an appendix to all manuals.
+#texinfo_appendices = []
+
+# If false, no module index is generated.
+#texinfo_domain_indices = True
+
+# How to display URL addresses: 'footnote', 'no', or 'inline'.
+#texinfo_show_urls = 'footnote'
+
+# If true, do not generate a @detailmenu in the "Top" node's menu.
+#texinfo_no_detailmenu = False
+
+
+# Example configuration for intersphinx: refer to the Python standard library.
+intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
+ 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
+ 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None),
+ 'matplotlib': ('http://matplotlib.sourceforge.net/', None)}
+
+sphinx_gallery_conf = {
+ 'examples_dirs': ['../../examples','../../examples/da'],
+ 'gallery_dirs': 'auto_examples',
+ 'backreferences_dir': '../modules/generated/',
+ 'reference_url': {
+ 'numpy': 'http://docs.scipy.org/doc/numpy-1.9.1',
+ 'scipy': 'http://docs.scipy.org/doc/scipy-0.17.0/reference'}
+}
diff --git a/docs/source/index.rst b/docs/source/index.rst
new file mode 100644
index 0000000..9078d35
--- /dev/null
+++ b/docs/source/index.rst
@@ -0,0 +1,29 @@
+.. POT documentation master file, created by
+ sphinx-quickstart on Mon Oct 24 11:10:10 2016.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+POT: Python Optimal Transport
+=============================
+
+Contents
+--------
+
+.. toctree::
+ :maxdepth: 2
+
+ self
+ quickstart
+ all
+ auto_examples/index
+
+.. include:: readme.rst
+ :start-line: 5
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
new file mode 100644
index 0000000..978eaff
--- /dev/null
+++ b/docs/source/quickstart.rst
@@ -0,0 +1,923 @@
+
+Quick start guide
+=================
+
+In the following we provide some pointers about which functions and classes
+to use for different problems related to optimal transport (OT) and machine
+learning. We refer when we can to concrete examples in the documentation that
+are also available as notebooks on the POT Github.
+
+This document is not a tutorial on numerical optimal transport. For this we strongly
+recommend to read the very nice book [15]_ .
+
+
+Optimal transport and Wasserstein distance
+------------------------------------------
+
+.. note::
+ In POT, most functions that solve OT or regularized OT problems have two
+ versions that return the OT matrix or the value of the optimal solution. For
+ instance :any:`ot.emd` return the OT matrix and :any:`ot.emd2` return the
+ Wassertsein distance. This approach has been implemented in practice for all
+ solvers that return an OT matrix (even Gromov-Wasserstsein)
+
+Solving optimal transport
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The optimal transport problem between discrete distributions is often expressed
+as
+
+.. math::
+ \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+where :
+
+- :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`.
+- :math:`a` and :math:`b` are histograms on the simplex (positive, sum to 1) that represent the
+weights of each samples in the source an target distributions.
+
+Solving the linear program above can be done using the function :any:`ot.emd`
+that will return the optimal transport matrix :math:`\gamma^*`:
+
+.. code:: python
+
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T=ot.emd(a,b,M) # exact linear program
+
+The method implemented for solving the OT problem is the network simplex, it is
+implemented in C from [1]_. It has a complexity of :math:`O(n^3)` but the
+solver is quite efficient and uses sparsity of the solution.
+
+.. hint::
+ Examples of use for :any:`ot.emd` are available in :
+
+ - :any:`auto_examples/plot_OT_2D_samples`
+ - :any:`auto_examples/plot_OT_1D`
+ - :any:`auto_examples/plot_OT_L1_vs_L2`
+
+
+Computing Wasserstein distance
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The value of the OT solution is often more of interest than the OT matrix :
+
+.. math::
+ OT(a,b)=\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+
+It can computed from an already estimated OT matrix with
+:code:`np.sum(T*M)` or directly with the function :any:`ot.emd2`.
+
+.. code:: python
+
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ W=ot.emd2(a,b,M) # Wasserstein distance / EMD value
+
+Note that the well known `Wasserstein distance
+<https://en.wikipedia.org/wiki/Wasserstein_metric>`_ between distributions a and
+b is defined as
+
+
+ .. math::
+
+ W_p(a,b)=(\min_\gamma \sum_{i,j}\gamma_{i,j}\|x_i-y_j\|_p)^\frac{1}{p}
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+This means that if you want to compute the :math:`W_2` you need to compute the
+square root of :any:`ot.emd2` when providing
+:code:`M=ot.dist(xs,xt)` that use the squared euclidean distance by default. Computing
+the :math:`W_1` wasserstein distance can be done directly with :any:`ot.emd2`
+when providing :code:`M=ot.dist(xs,xt, metric='euclidean')` to use the euclidean
+distance.
+
+
+.. hint::
+ An example of use for :any:`ot.emd2` is available in :
+
+ - :any:`auto_examples/plot_compute_emd`
+
+
+Special cases
+^^^^^^^^^^^^^
+
+Note that the OT problem and the corresponding Wasserstein distance can in some
+special cases be computed very efficiently.
+
+For instance when the samples are in 1D, then the OT problem can be solved in
+:math:`O(n\log(n))` by using a simple sorting. In this case we provide the
+function :any:`ot.emd_1d` and :any:`ot.emd2_1d` to return respectively the OT
+matrix and value. Note that since the solution is very sparse the :code:`sparse`
+parameter of :any:`ot.emd_1d` allows for solving and returning the solution for
+very large problems. Note that in order to compute directly the :math:`W_p`
+Wasserstein distance in 1D we provide the function :any:`ot.wasserstein_1d` that
+takes :code:`p` as a parameter.
+
+Another special case for estimating OT and Monge mapping is between Gaussian
+distributions. In this case there exists a close form solution given in Remark
+2.29 in [15]_ and the Monge mapping is an affine function and can be
+also computed from the covariances and means of the source and target
+distributions. In the case when the finite sample dataset is supposed gaussian, we provide
+:any:`ot.da.OT_mapping_linear` that returns the parameters for the Monge
+mapping.
+
+
+Regularized Optimal Transport
+-----------------------------
+
+Recent developments have shown the interest of regularized OT both in terms of
+computational and statistical properties.
+We address in this section the regularized OT problems that can be expressed as
+
+.. math::
+ \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+
+where :
+
+- :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`.
+- :math:`a` and :math:`b` are histograms (positive, sum to 1) that represent the weights of each samples in the source an target distributions.
+- :math:`\Omega` is the regularization term.
+
+We discuss in the following specific algorithms that can be used depending on
+the regularization term.
+
+
+Entropic regularized OT
+^^^^^^^^^^^^^^^^^^^^^^^
+
+This is the most common regularization used for optimal transport. It has been
+proposed in the ML community by Marco Cuturi in his seminal paper [2]_. This
+regularization has the following expression
+
+.. math::
+ \Omega(\gamma)=\sum_{i,j}\gamma_{i,j}\log(\gamma_{i,j})
+
+
+The use of the regularization term above in the optimization problem has a very
+strong impact. First it makes the problem smooth which leads to new optimization
+procedures such as the well known Sinkhorn algorithm [2]_ or L-BFGS (see
+:any:`ot.smooth` ). Next it makes the problem
+strictly convex meaning that there will be a unique solution. Finally the
+solution of the resulting optimization problem can be expressed as:
+
+.. math::
+
+ \gamma_\lambda^*=\text{diag}(u)K\text{diag}(v)
+
+where :math:`u` and :math:`v` are vectors and :math:`K=\exp(-M/\lambda)` where
+the :math:`\exp` is taken component-wise. In order to solve the optimization
+problem, on can use an alternative projection algorithm called Sinkhorn-Knopp that can be very
+efficient for large values if regularization.
+
+The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
+:any:`ot.sinkhorn2` that return respectively the OT matrix and the value of the
+linear term. Note that the regularization parameter :math:`\lambda` in the
+equation above is given to those functions with the parameter :code:`reg`.
+
+ >>> import ot
+ >>> a=[.5,.5]
+ >>> b=[.5,.5]
+ >>> M=[[0.,1.],[1.,0.]]
+ >>> ot.sinkhorn(a,b,M,1)
+ array([[ 0.36552929, 0.13447071],
+ [ 0.13447071, 0.36552929]])
+
+More details about the algorithms used are given in the following note.
+
+.. note::
+ The main function to solve entropic regularized OT is :any:`ot.sinkhorn`.
+ This function is a wrapper and the parameter :code:`method` help you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
+ classic algorithm [2]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the
+ log stabilized version of the algorithm [9]_.
+ + :code:`method='sinkhorn_epsilon_scaling'` calls
+ :any:`ot.bregman.sinkhorn_epsilon_scaling` the epsilon scaling version
+ of the algorithm [9]_.
+ + :code:`method='greenkhorn'` calls :any:`ot.bregman.greenkhorn` the
+ greedy sinkhorn verison of the algorithm [22]_.
+
+ In addition to all those variants of sinkhorn, we have another
+ implementation solving the problem in the smooth dual or semi-dual in
+ :any:`ot.smooth`. This solver uses the :any:`scipy.optimize.minimize`
+ function to solve the smooth problem with :code:`L-BFGS-B` algorithm. Tu use
+ this solver, use functions :any:`ot.smooth.smooth_ot_dual` or
+ :any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='kl'` to
+ choose entropic/Kullbach Leibler regularization.
+
+
+Recently [23]_ introduced the sinkhorn divergence that build from entropic
+regularization to compute fast and differentiable geometric divergence between
+empirical distributions. Note that we provide a function that compute directly
+(with no need to pre compute the :code:`M` matrix)
+the sinkhorn divergence for empirical distributions in
+:any:`ot.bregman.empirical_sinkhorn_divergence`. Similarly one can compute the
+OT matrix and loss for empirical distributions with respectively
+:any:`ot.bregman.empirical_sinkhorn` and :any:`ot.bregman.empirical_sinkhorn2`.
+
+
+Finally note that we also provide in :any:`ot.stochastic` several implementation
+of stochastic solvers for entropic regularized OT [18]_ [19]_. Those pure Python
+implementations are not optimized for speed but provide a roust implementation
+of algorithms in [18]_ [19]_.
+
+.. hint::
+ Examples of use for :any:`ot.sinkhorn` are available in :
+
+ - :any:`auto_examples/plot_OT_2D_samples`
+ - :any:`auto_examples/plot_OT_1D`
+ - :any:`auto_examples/plot_OT_1D_smooth`
+ - :any:`auto_examples/plot_stochastic`
+
+
+Other regularization
+^^^^^^^^^^^^^^^^^^^^
+
+While entropic OT is the most common and favored in practice, there exist other
+kind of regularization. We provide in POT two specific solvers for other
+regularization terms, namely quadratic regularization and group lasso
+regularization. But we also provide in :any:`ot.optim` two generic solvers that allows solving any
+smooth regularization in practice.
+
+Quadratic regularization
+""""""""""""""""""""""""
+
+The first general regularization term we can solve is the quadratic
+regularization of the form
+
+.. math::
+ \Omega(\gamma)=\sum_{i,j} \gamma_{i,j}^2
+
+this regularization term has a similar effect to entropic regularization in
+densifying the OT matrix but it keeps some sort of sparsity that is lost with
+entropic regularization as soon as :math:`\lambda>0` [17]_. This problem can be
+solved with POT using solvers from :any:`ot.smooth`, more specifically
+functions :any:`ot.smooth.smooth_ot_dual` or
+:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
+choose the quadratic regularization.
+
+.. hint::
+ Examples of quadratic regularization are available in :
+
+ - :any:`auto_examples/plot_OT_1D_smooth`
+ - :any:`auto_examples/plot_optim_OTreg`
+
+
+
+Group Lasso regularization
+""""""""""""""""""""""""""
+
+Another regularization that has been used in recent years [5]_ is the group lasso
+regularization
+
+.. math::
+ \Omega(\gamma)=\sum_{j,G\in\mathcal{G}} \|\gamma_{G,j}\|_q^p
+
+where :math:`\mathcal{G}` contains non overlapping groups of lines in the OT
+matrix. This regularization proposed in [5]_ will promote sparsity at the group level and for
+instance will force target samples to get mass from a small number of groups.
+Note that the exact OT solution is already sparse so this regularization does
+not make sens if it is not combined with entropic regularization. Depending on
+the choice of :code:`p` and :code:`q`, the problem can be solved with different
+approaches. When :code:`q=1` and :code:`p<1` the problem is non convex but can
+be solved using an efficient majoration minimization approach with
+:any:`ot.sinkhorn_lpl1_mm`. When :code:`q=2` and :code:`p=1` we recover the
+convex group lasso and we provide a solver using generalized conditional
+gradient algorithm [7]_ in function
+:any:`ot.da.sinkhorn_l1l2_gl`.
+
+.. hint::
+ Examples of group Lasso regularization are available in :
+
+ - :any:`auto_examples/plot_otda_classes`
+ - :any:`auto_examples/plot_otda_d2`
+
+
+Generic solvers
+"""""""""""""""
+
+Finally we propose in POT generic solvers that can be used to solve any
+regularization as long as you can provide a function computing the
+regularization and a function computing its gradient (or sub-gradient).
+
+In order to solve
+
+.. math::
+ \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+you can use function :any:`ot.optim.cg` that will use a conditional gradient as
+proposed in [6]_ . You need to provide the regularization function as parameter
+``f`` and its gradient as parameter ``df``. Note that the conditional gradient relies on
+iterative solving of a linearization of the problem using the exact
+:any:`ot.emd` so it can be slow in practice. But, being an interior point
+algorithm, it always returns a
+transport matrix that does not violates the marginals.
+
+Another generic solver is proposed to solve the problem
+
+.. math::
+ \gamma^* = arg\min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j}+ \lambda_e\Omega_e(\gamma) + \lambda\Omega(\gamma)
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+where :math:`\Omega_e` is the entropic regularization. In this case we use a
+generalized conditional gradient [7]_ implemented in :any:`ot.optim.gcg` that
+does not linearize the entropic term but
+relies on :any:`ot.sinkhorn` for its iterations.
+
+.. hint::
+ An example of generic solvers are available in :
+
+ - :any:`auto_examples/plot_optim_OTreg`
+
+
+Wasserstein Barycenters
+-----------------------
+
+A Wasserstein barycenter is a distribution that minimize its Wasserstein
+distance with respect to other distributions [16]_. It corresponds to minimizing the
+following problem by searching a distribution :math:`\mu` such that
+
+.. math::
+ \min_\mu \quad \sum_{k} w_kW(\mu,\mu_k)
+
+
+In practice we model a distribution with a finite number of support position:
+
+.. math::
+ \mu=\sum_{i=1}^n a_i\delta_{x_i}
+
+where :math:`a` is an histogram on the simplex and the :math:`\{x_i\}` are the
+position of the support. We can clearly see here that optimizing :math:`\mu` can
+be done by searching for optimal weights :math:`a` or optimal support
+:math:`\{x_i\}` (optimizing both is also an option).
+We provide in POT solvers to estimate a discrete
+Wasserstein barycenter in both cases.
+
+Barycenters with fixed support
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When optimizing a barycenter with a fixed support, the optimization problem can
+be expressed as
+
+.. math::
+ \min_a \quad \sum_{k} w_k W(a,b_k)
+
+where :math:`b_k` are also weights in the simplex. In the non-regularized case,
+the problem above is a classical linear program. In this case we propose a
+solver :any:`ot.lp.barycenter` that rely on generic LP solvers. By default the
+function uses :any:`scipy.optimize.linprog`, but more efficient LP solvers from
+cvxopt can be also used by changing parameter :code:`solver`. Note that this problem
+requires to solve a very large linear program and can be very slow in
+practice.
+
+Similarly to the OT problem, OT barycenters can be computed in the regularized
+case. When using entropic regularization is used, the problem can be solved with a
+generalization of the sinkhorn algorithm based on bregman projections [3]_. This
+algorithm is provided in function :any:`ot.bregman.barycenter` also available as
+:any:`ot.barycenter`. In this case, the algorithm scales better to large
+distributions and rely only on matrix multiplications that can be performed in
+parallel.
+
+In addition to the speedup brought by regularization, one can also greatly
+accelerate the estimation of Wasserstein barycenter when the support has a
+separable structure [21]_. In the case of 2D images for instance one can replace
+the matrix vector production in the Bregman projections by convolution
+operators. We provide an implementation of this algorithm in function
+:any:`ot.bregman.convolutional_barycenter2d`.
+
+.. hint::
+ Examples of Wasserstein (:any:`ot.lp.barycenter`) and regularized Wasserstein
+ barycenter (:any:`ot.bregman.barycenter`) computation are available in :
+
+ - :any:`auto_examples/plot_barycenter_1D`
+ - :any:`auto_examples/plot_barycenter_lp_vs_entropic`
+
+ An example of convolutional barycenter
+ (:any:`ot.bregman.convolutional_barycenter2d`) computation
+ for 2D images is available
+ in :
+
+ - :any:`auto_examples/plot_convolutional_barycenter`
+
+
+
+Barycenters with free support
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Estimating the Wasserstein barycenter with free support but fixed weights
+corresponds to solving the following optimization problem:
+
+.. math::
+ \min_{\{x_i\}} \quad \sum_{k} w_kW(\mu,\mu_k)
+
+ s.t. \quad \mu=\sum_{i=1}^n a_i\delta_{x_i}
+
+We provide a solver based on [20]_ in
+:any:`ot.lp.free_support_barycenter`. This function minimize the problem and
+return a locally optimal support :math:`\{x_i\}` for uniform or given weights
+:math:`a`.
+
+ .. hint::
+
+ An example of the free support barycenter estimation is available
+ in :
+
+ - :any:`auto_examples/plot_free_support_barycenter`
+
+
+
+
+Monge mapping and Domain adaptation
+-----------------------------------
+
+The original transport problem investigated by Gaspard Monge was seeking for a
+mapping function that maps (or transports) between a source and target
+distribution but that minimizes the transport loss. The existence and uniqueness of this
+optimal mapping is still an open problem in the general case but has been proven
+for smooth distributions by Brenier in his eponym `theorem
+<https://who.rocq.inria.fr/Jean-David.Benamou/demiheure.pdf>`__. We provide in
+:any:`ot.da` several solvers for smooth Monge mapping estimation and domain
+adaptation from discrete distributions.
+
+Monge Mapping estimation
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+We now discuss several approaches that are implemented in POT to estimate or
+approximate a Monge mapping from finite distributions.
+
+First note that when the source and target distributions are supposed to be Gaussian
+distributions, there exists a close form solution for the mapping and its an
+affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
+:any:`ot.da.OT_mapping_linear` that return the operator :math:`A` and vector
+:math:`b`. Note that if the number of samples is too small there is a parameter
+:code:`reg` that provide a regularization for the covariance matrix estimation.
+
+For a more general mapping estimation we also provide the barycentric mapping
+proposed in [6]_ . It is implemented in the class :any:`ot.da.EMDTransport` and
+other transport based classes in :any:`ot.da` . Those classes are discussed more
+in the following but follow an interface similar to sklearn classes. Finally a
+method proposed in [8]_ that estimates a continuous mapping approximating the
+barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
+linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non linear mapping.
+
+ .. hint::
+
+ An example of the linear Monge mapping estimation is available
+ in :
+
+ - :any:`auto_examples/plot_otda_linear_mapping`
+
+Domain adaptation classes
+^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The use of OT for domain adaptation (OTDA) has been first proposed in [5]_ that also
+introduced the group Lasso regularization. The main idea of OTDA is to estimate
+a mapping of the samples between source and target distributions which allows to
+transport labeled source samples onto the target distribution with no labels.
+
+We provide several classes based on :any:`ot.da.BaseTransport` that provide
+several OT and mapping estimations. The interface of those classes is similar to
+classifiers in sklearn toolbox. At initialization, several parameters such as
+ regularization parameter value can be set. Then one needs to estimate the
+mapping with function :any:`ot.da.BaseTransport.fit`. Finally one can map the
+samples from source to target with :any:`ot.da.BaseTransport.transform` and
+from target to source with :any:`ot.da.BaseTransport.inverse_transform`.
+
+Here is
+an example for class :any:`ot.da.EMDTransport` :
+
+.. code::
+
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ Mapped_Xs= ot_emd.transform(Xs=Xs)
+
+A list of the provided implementation is given in the following note.
+
+.. note::
+
+ Here is a list of the OT mapping classes inheriting from
+ :any:`ot.da.BaseTransport`
+
+ * :any:`ot.da.EMDTransport` : Barycentric mapping with EMD transport
+ * :any:`ot.da.SinkhornTransport` : Barycentric mapping with Sinkhorn transport
+ * :any:`ot.da.SinkhornL1l2Transport` : Barycentric mapping with Sinkhorn +
+ group Lasso regularization [5]_
+ * :any:`ot.da.SinkhornLpl1Transport` : Barycentric mapping with Sinkhorn +
+ non convex group Lasso regularization [5]_
+ * :any:`ot.da.LinearTransport` : Linear mapping estimation between Gaussians
+ [14]_
+ * :any:`ot.da.MappingTransport` : Nonlinear mapping estimation [8]_
+
+.. hint::
+
+ Example of the use of OTDA classes are available in :
+
+ - :any:`auto_examples/plot_otda_color_images`
+ - :any:`auto_examples/plot_otda_mapping`
+ - :any:`auto_examples/plot_otda_mapping_colors_images`
+ - :any:`auto_examples/plot_otda_semi_supervised`
+
+Other applications
+------------------
+
+We discuss in the following several OT related problems and tools that has been
+proposed in the OT and machine learning community.
+
+Wasserstein Discriminant Analysis
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Wasserstein Discriminant Analysis [11]_ is a generalization of `Fisher Linear Discriminant
+Analysis <https://en.wikipedia.org/wiki/Linear_discriminant_analysis>`__ that
+allows discrimination between classes that are not linearly separable. It
+consist in finding a linear projector optimizing the following criterion
+
+.. math::
+ P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
+ OT_e(\mu_i\#P,\mu_j\#P)}
+
+where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
+loss and :math:`\mu_i` is the
+distribution of samples from class :math:`i`. :math:`P` is also constrained to
+be in the Stiefel manifold. WDA can be solved in POT using function
+:any:`ot.dr.wda`. It requires to have installed :code:`pymanopt` and
+:code:`autograd` for manifold optimization and automatic differentiation
+respectively. Note that we also provide the Fisher discriminant estimator in
+:any:`ot.dr.fda` for easy comparison.
+
+.. warning::
+ Note that due to the hard dependency on :code:`pymanopt` and
+ :code:`autograd`, :any:`ot.dr` is not imported by default. If you want to
+ use it you have to specifically import it with :code:`import ot.dr` .
+
+.. hint::
+
+ An example of the use of WDA is available in :
+
+ - :any:`auto_examples/plot_WDA`
+
+
+Unbalanced optimal transport
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Unbalanced OT is a relaxation of the entropy regularized OT problem where the violation of
+the constraint on the marginals is added to the objective of the optimization
+problem. The unbalanced OT metric between two unbalanced histograms a and b is defined as [25]_ [10]_:
+
+.. math::
+ W_u(a, b) = \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
+
+ s.t. \quad \gamma\geq 0
+
+
+where KL is the Kullback-Leibler divergence. This formulation allows for
+computing approximate mapping between distributions that do not have the same
+amount of mass. Interestingly the problem can be solved with a generalization of
+the Bregman projections algorithm [10]_. We provide a solver for unbalanced OT
+in :any:`ot.unbalanced`. Computing the optimal transport
+plan or the transport cost is similar to the balanced case. The Sinkhorn-Knopp
+algorithm is implemented in :any:`ot.sinkhorn_unbalanced` and :any:`ot.sinkhorn_unbalanced2`
+that return respectively the OT matrix and the value of the
+linear term.
+
+.. note::
+ The main function to solve entropic regularized UOT is :any:`ot.sinkhorn_unbalanced`.
+ This function is a wrapper and the parameter :code:`method` helps you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.sinkhorn_knopp_unbalanced`
+ the generalized Sinkhorn algorithm [25]_ [10]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.sinkhorn_stabilized_unbalanced`
+ the log stabilized version of the algorithm [10]_.
+
+
+.. hint::
+
+ Examples of the use of :any:`ot.sinkhorn_unbalanced` are available in :
+
+ - :any:`auto_examples/plot_UOT_1D`
+
+
+Unbalanced Barycenters
+^^^^^^^^^^^^^^^^^^^^^^
+
+As with balanced distributions, we can define a barycenter of a set of
+histograms with different masses as a Fréchet Mean:
+
+ .. math::
+ \min_{\mu} \quad \sum_{k} w_kW_u(\mu,\mu_k)
+
+Where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem
+can also be solved using generalized version of Sinkhorn's algorithm and it is
+implemented the main function :any:`ot.barycenter_unbalanced`.
+
+
+.. note::
+ The main function to compute UOT barycenters is :any:`ot.barycenter_unbalanced`.
+ This function is a wrapper and the parameter :code:`method` help you select
+ the actual algorithm used to solve the problem:
+
+ + :code:`method='sinkhorn'` calls :any:`ot.unbalanced.barycenter_unbalanced_sinkhorn_unbalanced`
+ the generalized Sinkhorn algorithm [10]_.
+ + :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.barycenter_unbalanced_stabilized`
+ the log stabilized version of the algorithm [10]_.
+
+
+.. hint::
+
+ Examples of the use of :any:`ot.barycenter_unbalanced` are available in :
+
+ - :any:`auto_examples/plot_UOT_barycenter_1D`
+
+
+Gromov-Wasserstein
+^^^^^^^^^^^^^^^^^^
+
+Gromov Wasserstein (GW) is a generalization of OT to distributions that do not lie in
+the same space [13]_. In this case one cannot compute distance between samples
+from the two distributions. [13]_ proposed instead to realign the metric spaces
+by computing a transport between distance matrices. The Gromow Wasserstein
+alignement between two distributions can be expressed as the one minimizing:
+
+.. math::
+ GW = \min_\gamma \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*\gamma_{i,j}*\gamma_{k,l}
+
+ s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
+
+where ::math:`C1` is the distance matrix between samples in the source
+distribution and :math:`C2` the one between samples in the target,
+:math:`L(C1_{i,k},C2_{j,l})` is a measure of similarity between
+:math:`C1_{i,k}` and :math:`C2_{j,l}` often chosen as
+:math:`L(C1_{i,k},C2_{j,l})=\|C1_{i,k}-C2_{j,l}\|^2`. The optimization problem
+above is a non-convex quadratic program but we provide a solver that finds a
+local minimum using conditional gradient in :any:`ot.gromov.gromov_wasserstein`.
+There also exists an entropic regularized variant of GW that has been proposed in
+[12]_ and we provide an implementation of their algorithm in
+:any:`ot.gromov.entropic_gromov_wasserstein`.
+
+Note that similarly to Wasserstein distance GW allows for the definition of GW
+barycenters that can be expressed as
+
+.. math::
+ \min_{C\geq 0} \quad \sum_{k} w_k GW(C,Ck)
+
+where :math:`Ck` is the distance matrix between samples in distribution
+:math:`k`. Note that interestingly the barycenter is defined as a symmetric
+positive matrix. We provide a block coordinate optimization procedure in
+:any:`ot.gromov.gromov_barycenters` and
+:any:`ot.gromov.entropic_gromov_barycenters` for non-regularized and regularized
+barycenters respectively.
+
+Finally note that recently a fusion between Wasserstein and GW, coined Fused
+Gromov-Wasserstein (FGW) has been proposed
+in [24]_. It allows to compute a similarity between objects that are only partly in
+the same space. As such it can be used to measure similarity between labeled
+graphs for instance and also provide computable barycenters.
+The implementations of FGW and FGW barycenter is provided in functions
+:any:`ot.gromov.fused_gromov_wasserstein` and :any:`ot.gromov.fgw_barycenters`.
+
+.. hint::
+
+ Examples of computation of GW, regularized G and FGW are available in :
+
+ - :any:`auto_examples/plot_gromov`
+ - :any:`auto_examples/plot_fgw`
+
+ Examples of GW, regularized GW and FGW barycenters are available in :
+
+ - :any:`auto_examples/plot_gromov_barycenter`
+ - :any:`auto_examples/plot_barycenter_fgw`
+
+
+GPU acceleration
+^^^^^^^^^^^^^^^^
+
+We provide several implementation of our OT solvers in :any:`ot.gpu`. Those
+implementations use the :code:`cupy` toolbox that obviously need to be installed.
+
+
+.. note::
+
+ Several implementations of POT functions (mainly those relying on linear
+ algebra) have been implemented in :any:`ot.gpu`. Here is a short list on the
+ main entries:
+
+ - :any:`ot.gpu.dist` : computation of distance matrix
+ - :any:`ot.gpu.sinkhorn` : computation of sinkhorn
+ - :any:`ot.gpu.sinkhorn_lpl1_mm` : computation of sinkhorn + group lasso
+
+Note that while the :any:`ot.gpu` module has been designed to be compatible with
+POT, calling its function with :any:`numpy` arrays will incur a large overhead due to
+the memory copy of the array on GPU prior to computation and conversion of the
+array after computation. To avoid this overhead, we provide functions
+:any:`ot.gpu.to_gpu` and :any:`ot.gpu.to_np` that perform the conversion
+explicitly.
+
+
+.. warning::
+ Note that due to the hard dependency on :code:`cupy`, :any:`ot.gpu` is not
+ imported by default. If you want to
+ use it you have to specifically import it with :code:`import ot.gpu` .
+
+
+FAQ
+---
+
+
+
+1. **How to solve a discrete optimal transport problem ?**
+
+ The solver for discrete OT is the function :py:mod:`ot.emd` that returns
+ the OT transport matrix. If you want to solve a regularized OT you can
+ use :py:mod:`ot.sinkhorn`.
+
+
+ Here is a simple use case:
+
+ .. code:: python
+
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T=ot.emd(a,b,M) # exact linear program
+ T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+
+ More detailed examples can be seen on this example:
+ :doc:`auto_examples/plot_OT_2D_samples`
+
+
+2. **pip install POT fails with error : ImportError: No module named Cython.Build**
+
+ As discussed shortly in the README file. POT requires to have :code:`numpy`
+ and :code:`cython` installed to build. This corner case is not yet handled
+ by :code:`pip` and for now you need to install both library prior to
+ installing POT.
+
+ Note that this problem do not occur when using conda-forge since the packages
+ there are pre-compiled.
+
+ See `Issue #59 <https://github.com/rflamary/POT/issues/59>`__ for more
+ details.
+
+3. **Why is Sinkhorn slower than EMD ?**
+
+ This might come from the choice of the regularization term. The speed of
+ convergence of sinkhorn depends directly on this term [22]_ and when the
+ regularization gets very small the problem try and approximate the exact OT
+ which leads to slow convergence in addition to numerical problems. In other
+ words, for large regularization sinkhorn will be very fast to converge, for
+ small regularization (when you need an OT matrix close to the true OT), it
+ might be quicker to use the EMD solver.
+
+ Also note that the numpy implementation of the sinkhorn can use parallel
+ computation depending on the configuration of your system but very important
+ speedup can be obtained by using a GPU implementation since all operations
+ are matrix/vector products.
+
+4. **Using GPU fails with error: module 'ot' has no attribute 'gpu'**
+
+ In order to limit import time and hard dependencies in POT. we do not import
+ some sub-modules automatically with :code:`import ot`. In order to use the
+ acceleration in :any:`ot.gpu` you need first to import is with
+ :code:`import ot.gpu`.
+
+ See `Issue #85 <https://github.com/rflamary/POT/issues/85>`__ and :any:`ot.gpu`
+ for more details.
+
+
+References
+----------
+
+.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+ December). `Displacement nterpolation using Lagrangian mass transport
+ <https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf>`__.
+ In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+
+.. [2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of
+ optimal transport <https://arxiv.org/pdf/1306.0895.pdf>`__. In Advances
+ in Neural Information Processing Systems (pp. 2292-2300).
+
+.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+ (2015). `Iterative Bregman projections for regularized transportation
+ problems <https://arxiv.org/pdf/1412.5154.pdf>`__. SIAM Journal on
+ Scientific Computing, 37(2), A1111-A1138.
+
+.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+ `Supervised planetary unmixing with optimal
+ transport <https://hal.archives-ouvertes.fr/hal-01377236/document>`__,
+ Whorkshop on Hyperspectral Image and Signal Processing : Evolution in
+ Remote Sensing (WHISPERS), 2016.
+
+.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport
+ for Domain Adaptation <https://arxiv.org/pdf/1507.00504.pdf>`__, in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence , vol.PP,
+ no.99, pp.1-1
+
+.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ `Regularized discrete optimal
+ transport <https://arxiv.org/pdf/1307.5551.pdf>`__. SIAM Journal on
+ Imaging Sciences, 7(3), 1853-1882.
+
+.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). `Generalized
+ conditional gradient: analysis of convergence and
+ applications <https://arxiv.org/pdf/1510.06567.pdf>`__. arXiv preprint
+ arXiv:1510.06567.
+
+.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), `Mapping
+ estimation for discrete optimal
+ transport <http://remi.flamary.com/biblio/perrot2016mapping.pdf>`__,
+ Neural Information Processing Systems (NIPS).
+
+.. [9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for
+ Entropy Regularized Transport
+ Problems <https://arxiv.org/pdf/1610.06519.pdf>`__. arXiv preprint
+ arXiv:1610.06519.
+
+.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+ `Scaling algorithms for unbalanced transport
+ problems <https://arxiv.org/pdf/1607.05816.pdf>`__. arXiv preprint
+ arXiv:1607.05816.
+
+.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+ `Wasserstein Discriminant
+ Analysis <https://arxiv.org/pdf/1608.08063.pdf>`__. arXiv preprint
+ arXiv:1608.08063.
+
+.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+ `Gromov-Wasserstein averaging of kernel and distance
+ matrices <http://proceedings.mlr.press/v48/peyre16.html>`__
+ International Conference on Machine Learning (ICML).
+
+.. [13] Mémoli, Facundo (2011). `Gromov–Wasserstein distances and the
+ metric approach to object
+ matching <https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf>`__.
+ Foundations of computational mathematics 11.4 : 417-487.
+
+.. [14] Knott, M. and Smith, C. S. (1984). `On the optimal mapping of
+ distributions <https://link.springer.com/article/10.1007/BF00934745>`__,
+ Journal of Optimization Theory and Applications Vol 43.
+
+.. [15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal
+ Transport <https://arxiv.org/pdf/1803.00567.pdf>`__ .
+
+.. [16] Agueh, M., & Carlier, G. (2011). `Barycenters in the Wasserstein
+ space <https://hal.archives-ouvertes.fr/hal-00637399/document>`__. SIAM
+ Journal on Mathematical Analysis, 43(2), 904-924.
+
+.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). `Smooth and Sparse
+ Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
+ the Twenty-First International Conference on Artificial Intelligence and
+ Statistics (AISTATS).
+
+.. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
+ Optimization for Large-scale Optimal
+ Transport <https://arxiv.org/abs/1605.08527>`__. Advances in Neural
+ Information Processing Systems (2016).
+
+.. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
+ A.& Blondel, M. `Large-scale Optimal Transport and Mapping
+ Estimation <https://arxiv.org/pdf/1711.02283.pdf>`__. International
+ Conference on Learning Representation (2018)
+
+.. [20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
+ Barycenters <http://proceedings.mlr.press/v32/cuturi14.html>`__.
+ International Conference in Machine Learning
+
+.. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
+ Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
+ Efficient optimal transportation on geometric
+ domains <https://dl.acm.org/citation.cfm?id=2766963>`__. ACM
+ Transactions on Graphics (TOG), 34(4), 66.
+
+.. [22] J. Altschuler, J.Weed, P. Rigollet, (2017) `Near-linear time
+ approximation algorithms for optimal transport via Sinkhorn
+ iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
+ Advances in Neural Information Processing Systems (NIPS) 31
+
+.. [23] Aude, G., Peyré, G., Cuturi, M., `Learning Generative Models with
+ Sinkhorn Divergences <https://arxiv.org/abs/1706.00292>`__, Proceedings
+ of the Twenty-First International Conference on Artficial Intelligence
+ and Statistics, (AISTATS) 21, 2018
+
+.. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
+ (2019). `Optimal Transport for structured data with application on
+ graphs <http://proceedings.mlr.press/v97/titouan19a.html>`__ Proceedings
+ of the 36th International Conference on Machine Learning (ICML).
+
+.. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
+ Learning with a Wasserstein Loss, Advances in Neural Information
+ Processing Systems (NIPS) 2015
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
new file mode 100644
index 0000000..0871779
--- /dev/null
+++ b/docs/source/readme.rst
@@ -0,0 +1,407 @@
+POT: Python Optimal Transport
+=============================
+
+|PyPI version| |Anaconda Cloud| |Build Status| |Documentation Status|
+|Downloads| |Anaconda downloads| |License|
+
+This open source Python library provide several solvers for optimization
+problems related to Optimal Transport for signal, image processing and
+machine learning.
+
+It provides the following solvers:
+
+- OT Network Flow solver for the linear program/ Earth Movers Distance
+ [1].
+- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2],
+ stabilized version [9][10] and greedy Sinkhorn [22] with optional GPU
+ implementation (requires cupy).
+- Sinkhorn divergence [23] and entropic regularization OT from
+ empirical data.
+- Smooth optimal transport solvers (dual and semi-dual) for KL and
+ squared L2 regularizations [17].
+- Non regularized Wasserstein barycenters [16] with LP solver (only
+ small scale).
+- Bregman projections for Wasserstein barycenter [3], convolutional
+ barycenter [21] and unmixing [4].
+- Optimal transport for domain adaptation with group lasso
+ regularization [5]
+- Conditional gradient [6] and Generalized conditional gradient for
+ regularized OT [7].
+- Linear OT [14] and Joint OT matrix and mapping estimation [8].
+- Wasserstein Discriminant Analysis [11] (requires autograd +
+ pymanopt).
+- Gromov-Wasserstein distances and barycenters ([13] and regularized
+ [12])
+- Stochastic Optimization for Large-scale Optimal Transport (semi-dual
+ problem [18] and dual problem [19])
+- Non regularized free support Wasserstein barycenters [20].
+- Unbalanced OT with KL relaxation distance and barycenter [10, 25].
+
+Some demonstrations (both in Python and Jupyter Notebook format) are
+available in the examples folder.
+
+Using and citing the toolbox
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you use this toolbox in your research and find it useful, please cite
+POT using the following bibtex reference:
+
+::
+
+ @misc{flamary2017pot,
+ title={POT Python Optimal Transport library},
+ author={Flamary, R{'e}mi and Courty, Nicolas},
+ url={https://github.com/rflamary/POT},
+ year={2017}
+ }
+
+Installation
+------------
+
+The library has been tested on Linux, MacOSX and Windows. It requires a
+C++ compiler for using the EMD solver and relies on the following Python
+modules:
+
+- Numpy (>=1.11)
+- Scipy (>=1.0)
+- Cython (>=0.23)
+- Matplotlib (>=1.5)
+
+Pip installation
+^^^^^^^^^^^^^^^^
+
+Note that due to a limitation of pip, ``cython`` and ``numpy`` need to
+be installed prior to installing POT. This can be done easily with
+
+::
+
+ pip install numpy cython
+
+You can install the toolbox through PyPI with:
+
+::
+
+ pip install POT
+
+or get the very latest version by downloading it and then running:
+
+::
+
+ python setup.py install --user # for user install (no root)
+
+Anaconda installation with conda-forge
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+If you use the Anaconda python distribution, POT is available in
+`conda-forge <https://conda-forge.org>`__. To install it and the
+required dependencies:
+
+::
+
+ conda install -c conda-forge pot
+
+Post installation check
+^^^^^^^^^^^^^^^^^^^^^^^
+
+After a correct installation, you should be able to import the module
+without errors:
+
+.. code:: python
+
+ import ot
+
+Note that for easier access the module is name ot instead of pot.
+
+Dependencies
+~~~~~~~~~~~~
+
+Some sub-modules require additional dependences which are discussed
+below
+
+- **ot.dr** (Wasserstein dimensionality reduction) depends on autograd
+ and pymanopt that can be installed with:
+
+ ::
+
+ pip install pymanopt autograd
+
+- **ot.gpu** (GPU accelerated OT) depends on cupy that have to be
+ installed following instructions on `this
+ page <https://docs-cupy.chainer.org/en/stable/install.html>`__.
+
+obviously you need CUDA installed and a compatible GPU.
+
+Examples
+--------
+
+Short examples
+~~~~~~~~~~~~~~
+
+- Import the toolbox
+
+ .. code:: python
+
+ import ot
+
+- Compute Wasserstein distances
+
+ .. code:: python
+
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ Wd=ot.emd2(a,b,M) # exact linear program
+ Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
+ # if b is a matrix compute all distances to a and return a vector
+
+- Compute OT matrix
+
+ .. code:: python
+
+ # a,b are 1D histograms (sum to 1 and positive)
+ # M is the ground cost matrix
+ T=ot.emd(a,b,M) # exact linear program
+ T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
+
+- Compute Wasserstein barycenter
+
+ .. code:: python
+
+ # A is a n*d matrix containing d 1D histograms
+ # M is the ground cost matrix
+ ba=ot.barycenter(A,M,reg) # reg is regularization parameter
+
+Examples and Notebooks
+~~~~~~~~~~~~~~~~~~~~~~
+
+The examples folder contain several examples and use case for the
+library. The full documentation is available on
+`Readthedocs <http://pot.readthedocs.io/>`__.
+
+Here is a list of the Python notebooks available
+`here <https://github.com/rflamary/POT/blob/master/notebooks/>`__ if you
+want a quick look:
+
+- `1D optimal
+ transport <https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_1D.ipynb>`__
+- `OT Ground
+ Loss <https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_L1_vs_L2.ipynb>`__
+- `Multiple EMD
+ computation <https://github.com/rflamary/POT/blob/master/notebooks/plot_compute_emd.ipynb>`__
+- `2D optimal transport on empirical
+ distributions <https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_2D_samples.ipynb>`__
+- `1D Wasserstein
+ barycenter <https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_1D.ipynb>`__
+- `OT with user provided
+ regularization <https://github.com/rflamary/POT/blob/master/notebooks/plot_optim_OTreg.ipynb>`__
+- `Domain adaptation with optimal
+ transport <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_d2.ipynb>`__
+- `Color transfer in
+ images <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_color_images.ipynb>`__
+- `OT mapping estimation for domain
+ adaptation <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping.ipynb>`__
+- `OT mapping estimation for color transfer in
+ images <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping_colors_images.ipynb>`__
+- `Wasserstein Discriminant
+ Analysis <https://github.com/rflamary/POT/blob/master/notebooks/plot_WDA.ipynb>`__
+- `Gromov
+ Wasserstein <https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb>`__
+- `Gromov Wasserstein
+ Barycenter <https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov_barycenter.ipynb>`__
+
+You can also see the notebooks with `Jupyter
+nbviewer <https://nbviewer.jupyter.org/github/rflamary/POT/tree/master/notebooks/>`__.
+
+Acknowledgements
+----------------
+
+This toolbox has been created and is maintained by
+
+- `Rémi Flamary <http://remi.flamary.com/>`__
+- `Nicolas Courty <http://people.irisa.fr/Nicolas.Courty/>`__
+
+The contributors to this library are
+
+- `Alexandre Gramfort <http://alexandre.gramfort.net/>`__
+- `Laetitia Chapel <http://people.irisa.fr/Laetitia.Chapel/>`__
+- `Michael Perrot <http://perso.univ-st-etienne.fr/pem82055/>`__
+ (Mapping estimation)
+- `Léo Gautheron <https://github.com/aje>`__ (GPU implementation)
+- `Nathalie
+ Gayraud <https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1>`__
+- `Stanislas Chambon <https://slasnista.github.io/>`__
+- `Antoine Rolet <https://arolet.github.io/>`__
+- Erwan Vautier (Gromov-Wasserstein)
+- `Kilian Fatras <https://kilianfatras.github.io/>`__
+- `Alain
+ Rakotomamonjy <https://sites.google.com/site/alainrakotomamonjy/home>`__
+- `Vayer Titouan <https://tvayer.github.io/>`__
+- `Hicham Janati <https://hichamjanati.github.io/>`__ (Unbalanced OT)
+- `Romain Tavenard <https://rtavenar.github.io/>`__ (1d Wasserstein)
+
+This toolbox benefit a lot from open source research and we would like
+to thank the following persons for providing some code (in various
+languages):
+
+- `Gabriel Peyré <http://gpeyre.github.io/>`__ (Wasserstein Barycenters
+ in Matlab)
+- `Nicolas Bonneel <http://liris.cnrs.fr/~nbonneel/>`__ ( C++ code for
+ EMD)
+- `Marco Cuturi <http://marcocuturi.net/>`__ (Sinkhorn Knopp in
+ Matlab/Cuda)
+
+Contributions and code of conduct
+---------------------------------
+
+Every contribution is welcome and should respect the `contribution
+guidelines <CONTRIBUTING.md>`__. Each member of the project is expected
+to follow the `code of conduct <CODE_OF_CONDUCT.md>`__.
+
+Support
+-------
+
+You can ask questions and join the development discussion:
+
+- On the `POT Slack channel <https://pot-toolbox.slack.com>`__
+- On the POT `mailing
+ list <https://mail.python.org/mm3/mailman3/lists/pot.python.org/>`__
+
+You can also post bug reports and feature requests in Github issues.
+Make sure to read our `guidelines <CONTRIBUTING.md>`__ first.
+
+References
+----------
+
+[1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
+December). `Displacement interpolation using Lagrangian mass
+transport <https://people.csail.mit.edu/sparis/publi/2011/sigasia/Bonneel_11_Displacement_Interpolation.pdf>`__.
+In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
+
+[2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of
+optimal transport <https://arxiv.org/pdf/1306.0895.pdf>`__. In Advances
+in Neural Information Processing Systems (pp. 2292-2300).
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
+(2015). `Iterative Bregman projections for regularized transportation
+problems <https://arxiv.org/pdf/1412.5154.pdf>`__. SIAM Journal on
+Scientific Computing, 37(2), A1111-A1138.
+
+[4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
+`Supervised planetary unmixing with optimal
+transport <https://hal.archives-ouvertes.fr/hal-01377236/document>`__,
+Whorkshop on Hyperspectral Image and Signal Processing : Evolution in
+Remote Sensing (WHISPERS), 2016.
+
+[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport
+for Domain Adaptation <https://arxiv.org/pdf/1507.00504.pdf>`__, in IEEE
+Transactions on Pattern Analysis and Machine Intelligence , vol.PP,
+no.99, pp.1-1
+
+[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+`Regularized discrete optimal
+transport <https://arxiv.org/pdf/1307.5551.pdf>`__. SIAM Journal on
+Imaging Sciences, 7(3), 1853-1882.
+
+[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). `Generalized
+conditional gradient: analysis of convergence and
+applications <https://arxiv.org/pdf/1510.06567.pdf>`__. arXiv preprint
+arXiv:1510.06567.
+
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), `Mapping
+estimation for discrete optimal
+transport <http://remi.flamary.com/biblio/perrot2016mapping.pdf>`__,
+Neural Information Processing Systems (NIPS).
+
+[9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for
+Entropy Regularized Transport
+Problems <https://arxiv.org/pdf/1610.06519.pdf>`__. arXiv preprint
+arXiv:1610.06519.
+
+[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
+`Scaling algorithms for unbalanced transport
+problems <https://arxiv.org/pdf/1607.05816.pdf>`__. arXiv preprint
+arXiv:1607.05816.
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+`Wasserstein Discriminant
+Analysis <https://arxiv.org/pdf/1608.08063.pdf>`__. arXiv preprint
+arXiv:1608.08063.
+
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
+`Gromov-Wasserstein averaging of kernel and distance
+matrices <http://proceedings.mlr.press/v48/peyre16.html>`__
+International Conference on Machine Learning (ICML).
+
+[13] Mémoli, Facundo (2011). `Gromov–Wasserstein distances and the
+metric approach to object
+matching <https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf>`__.
+Foundations of computational mathematics 11.4 : 417-487.
+
+[14] Knott, M. and Smith, C. S. (1984).`On the optimal mapping of
+distributions <https://link.springer.com/article/10.1007/BF00934745>`__,
+Journal of Optimization Theory and Applications Vol 43.
+
+[15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal
+Transport <https://arxiv.org/pdf/1803.00567.pdf>`__ .
+
+[16] Agueh, M., & Carlier, G. (2011). `Barycenters in the Wasserstein
+space <https://hal.archives-ouvertes.fr/hal-00637399/document>`__. SIAM
+Journal on Mathematical Analysis, 43(2), 904-924.
+
+[17] Blondel, M., Seguy, V., & Rolet, A. (2018). `Smooth and Sparse
+Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
+the Twenty-First International Conference on Artificial Intelligence and
+Statistics (AISTATS).
+
+[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
+Optimization for Large-scale Optimal
+Transport <https://arxiv.org/abs/1605.08527>`__. Advances in Neural
+Information Processing Systems (2016).
+
+[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
+A.& Blondel, M. `Large-scale Optimal Transport and Mapping
+Estimation <https://arxiv.org/pdf/1711.02283.pdf>`__. International
+Conference on Learning Representation (2018)
+
+[20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
+Barycenters <http://proceedings.mlr.press/v32/cuturi14.html>`__.
+International Conference in Machine Learning
+
+[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
+Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
+Efficient optimal transportation on geometric
+domains <https://dl.acm.org/citation.cfm?id=2766963>`__. ACM
+Transactions on Graphics (TOG), 34(4), 66.
+
+[22] J. Altschuler, J.Weed, P. Rigollet, (2017) `Near-linear time
+approximation algorithms for optimal transport via Sinkhorn
+iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
+Advances in Neural Information Processing Systems (NIPS) 31
+
+[23] Aude, G., Peyré, G., Cuturi, M., `Learning Generative Models with
+Sinkhorn Divergences <https://arxiv.org/abs/1706.00292>`__, Proceedings
+of the Twenty-First International Conference on Artficial Intelligence
+and Statistics, (AISTATS) 21, 2018
+
+[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
+(2019). `Optimal Transport for structured data with application on
+graphs <http://proceedings.mlr.press/v97/titouan19a.html>`__ Proceedings
+of the 36th International Conference on Machine Learning (ICML).
+
+[25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2019).
+`Learning with a Wasserstein Loss <http://cbcl.mit.edu/wasserstein/>`__
+Advances in Neural Information Processing Systems (NIPS).
+
+.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
+ :target: https://badge.fury.io/py/POT
+.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
+ :target: https://anaconda.org/conda-forge/pot
+.. |Build Status| image:: https://travis-ci.org/rflamary/POT.svg?branch=master
+ :target: https://travis-ci.org/rflamary/POT
+.. |Documentation Status| image:: https://readthedocs.org/projects/pot/badge/?version=latest
+ :target: http://pot.readthedocs.io/en/latest/?badge=latest
+.. |Downloads| image:: https://pepy.tech/badge/pot
+ :target: https://pepy.tech/project/pot
+.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg
+ :target: https://anaconda.org/conda-forge/pot
+.. |License| image:: https://anaconda.org/conda-forge/pot/badges/license.svg
+ :target: https://github.com/rflamary/POT/blob/master/LICENSE