path: root/docs/source
Diffstat (limited to 'docs/source')
-rw-r--r--  docs/source/auto_examples/auto_examples_jupyter.zip | bin 51426 -> 89148 bytes
-rw-r--r--  docs/source/auto_examples/auto_examples_python.zip | bin 35678 -> 59370 bytes
-rw-r--r--  docs/source/auto_examples/demo_OT_1D_test.ipynb | 54
-rw-r--r--  docs/source/auto_examples/demo_OT_1D_test.py | 71
-rw-r--r--  docs/source/auto_examples/demo_OT_1D_test.rst | 99
-rw-r--r--  docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb | 54
-rw-r--r--  docs/source/auto_examples/demo_OT_2D_sampleslarge.py | 78
-rw-r--r--  docs/source/auto_examples/demo_OT_2D_sampleslarge.rst | 106
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png | bin 52753 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png | bin 87798 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png | bin 167396 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png | bin 82929 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png | bin 53561 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png | bin 193523 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png | bin 237854 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png | bin 472911 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png | bin 44168 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png | bin 111565 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png | bin 237854 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png | bin 429859 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png | bin 27639 -> 21303 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png | bin 25126 -> 21334 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png | bin 19634 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png | bin 21449 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png | bin 0 -> 16995 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png | bin 0 -> 18923 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png | bin 22199 -> 20832 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png | bin 21036 -> 20827 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png | bin 9632 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png | bin 91630 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png | bin 9495 -> 9613 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png | bin 23476 -> 82797 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png | bin 0 -> 14508 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png | bin 0 -> 95761 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png | bin 22451 -> 11710 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png | bin 32795 -> 17184 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png | bin 38958 -> 38780 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png | bin 17324 -> 11710 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png | bin 28210 -> 38780 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png | bin 77009 -> 38780 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png | bin 0 -> 14117 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png | bin 0 -> 18696 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png | bin 0 -> 21300 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png | bin 0 -> 21300 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png | bin 42791 -> 56604 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png | bin 0 -> 87031 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png | bin 35837 -> 20512 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png | bin 59327 -> 41555 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png | bin 132247 -> 41555 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png | bin 125411 -> 105696 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png | bin 0 -> 108687 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png | bin 0 -> 105696 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png | bin 146617 -> 162612 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png | bin 38746 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png | bin 0 -> 29276 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png | bin 0 -> 38748 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png | bin 0 -> 46633 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png | bin 0 -> 16945 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png | bin 0 -> 16530 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png | bin 0 -> 47537 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png | bin 20684 -> 16995 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png | bin 21750 -> 18588 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png | bin 22971 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png | bin 0 -> 19258 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png | bin 0 -> 20440 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png | bin 0 -> 50899 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png | bin 0 -> 197590 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png | bin 0 -> 144957 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png | bin 0 -> 50401 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png | bin 0 -> 234564 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png | bin 0 -> 131063 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png | bin 0 -> 213055 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png | bin 0 -> 99762 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png | bin 0 -> 36766 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png | bin 0 -> 75842 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png | bin 0 -> 165592 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png | bin 0 -> 80722 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png | bin 0 -> 541314 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png | bin 0 -> 153695 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png | bin 0 -> 37987 bytes
-rw-r--r--  docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png | bin 0 -> 74267 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png | bin 34799 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png | bin 34581 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png | bin 52919 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png | bin 52919 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png | bin 26370 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png | bin 21175 -> 18227 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png | bin 23897 -> 22134 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png | bin 22735 -> 10935 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png | bin 2894 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png | bin 66426 -> 86417 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png | bin 19862 -> 16522 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png | bin 72940 -> 80805 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png | bin 0 -> 34183 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png | bin 0 -> 30843 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png | bin 21750 -> 3101 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png | bin 0 -> 30868 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png | bin 0 -> 51085 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png | bin 0 -> 52743 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png | bin 0 -> 58315 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png | bin 0 -> 18478 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png | bin 0 -> 64710 bytes
-rw-r--r--  docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png | bin 3101 -> 0 bytes
-rw-r--r--  docs/source/auto_examples/index.rst | 140
-rw-r--r--  docs/source/auto_examples/plot_OTDA_2D.ipynb | 54
-rw-r--r--  docs/source/auto_examples/plot_OTDA_2D.py | 120
-rw-r--r--  docs/source/auto_examples/plot_OTDA_2D.rst | 175
-rw-r--r--  docs/source/auto_examples/plot_OTDA_classes.ipynb | 54
-rw-r--r--  docs/source/auto_examples/plot_OTDA_classes.py | 112
-rw-r--r--  docs/source/auto_examples/plot_OTDA_classes.rst | 190
-rw-r--r--  docs/source/auto_examples/plot_OTDA_color_images.ipynb | 54
-rw-r--r--  docs/source/auto_examples/plot_OTDA_color_images.py | 145
-rw-r--r--  docs/source/auto_examples/plot_OTDA_color_images.rst | 191
-rw-r--r--  docs/source/auto_examples/plot_OTDA_mapping.ipynb | 54
-rw-r--r--  docs/source/auto_examples/plot_OTDA_mapping.py | 110
-rw-r--r--  docs/source/auto_examples/plot_OTDA_mapping.rst | 186
-rw-r--r--  docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb | 54
-rw-r--r--  docs/source/auto_examples/plot_OTDA_mapping_color_images.py | 158
-rw-r--r--  docs/source/auto_examples/plot_OTDA_mapping_color_images.rst | 246
-rw-r--r--  docs/source/auto_examples/plot_OT_1D.ipynb | 76
-rw-r--r--  docs/source/auto_examples/plot_OT_1D.py | 65
-rw-r--r--  docs/source/auto_examples/plot_OT_1D.rst | 173
-rw-r--r--  docs/source/auto_examples/plot_OT_2D_samples.ipynb | 76
-rw-r--r--  docs/source/auto_examples/plot_OT_2D_samples.py | 75
-rw-r--r--  docs/source/auto_examples/plot_OT_2D_samples.rst | 172
-rw-r--r--  docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb | 76
-rw-r--r--  docs/source/auto_examples/plot_OT_L1_vs_L2.py | 285
-rw-r--r--  docs/source/auto_examples/plot_OT_L1_vs_L2.rst | 350
-rw-r--r--  docs/source/auto_examples/plot_OT_conv.ipynb | 54
-rw-r--r--  docs/source/auto_examples/plot_OT_conv.py | 200
-rw-r--r--  docs/source/auto_examples/plot_OT_conv.rst | 241
-rw-r--r--  docs/source/auto_examples/plot_WDA.ipynb | 94
-rw-r--r--  docs/source/auto_examples/plot_WDA.py | 122
-rw-r--r--  docs/source/auto_examples/plot_WDA.rst | 236
-rw-r--r--  docs/source/auto_examples/plot_barycenter_1D.ipynb | 76
-rw-r--r--  docs/source/auto_examples/plot_barycenter_1D.py | 136
-rw-r--r--  docs/source/auto_examples/plot_barycenter_1D.rst | 211
-rw-r--r--  docs/source/auto_examples/plot_compute_emd.ipynb | 76
-rw-r--r--  docs/source/auto_examples/plot_compute_emd.py | 92
-rw-r--r--  docs/source/auto_examples/plot_compute_emd.rst | 153
-rw-r--r--  docs/source/auto_examples/plot_gromov.ipynb | 126
-rw-r--r--  docs/source/auto_examples/plot_gromov.py | 93
-rw-r--r--  docs/source/auto_examples/plot_gromov.rst | 180
-rw-r--r--  docs/source/auto_examples/plot_gromov_barycenter.ipynb | 126
-rw-r--r--  docs/source/auto_examples/plot_gromov_barycenter.py | 248
-rw-r--r--  docs/source/auto_examples/plot_gromov_barycenter.rst | 324
-rw-r--r--  docs/source/auto_examples/plot_optim_OTreg.ipynb | 94
-rw-r--r--  docs/source/auto_examples/plot_optim_OTreg.py | 111
-rw-r--r--  docs/source/auto_examples/plot_optim_OTreg.rst | 615
-rw-r--r--  docs/source/auto_examples/plot_otda_classes.ipynb | 126
-rw-r--r--  docs/source/auto_examples/plot_otda_classes.py | 150
-rw-r--r--  docs/source/auto_examples/plot_otda_classes.rst | 258
-rw-r--r--  docs/source/auto_examples/plot_otda_color_images.ipynb | 144
-rw-r--r--  docs/source/auto_examples/plot_otda_color_images.py | 165
-rw-r--r--  docs/source/auto_examples/plot_otda_color_images.rst | 257
-rw-r--r--  docs/source/auto_examples/plot_otda_d2.ipynb | 144
-rw-r--r--  docs/source/auto_examples/plot_otda_d2.py | 172
-rw-r--r--  docs/source/auto_examples/plot_otda_d2.rst | 264
-rw-r--r--  docs/source/auto_examples/plot_otda_mapping.ipynb | 126
-rw-r--r--  docs/source/auto_examples/plot_otda_mapping.py | 125
-rw-r--r--  docs/source/auto_examples/plot_otda_mapping.rst | 230
-rw-r--r--  docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb | 144
-rw-r--r--  docs/source/auto_examples/plot_otda_mapping_colors_images.py | 174
-rw-r--r--  docs/source/auto_examples/plot_otda_mapping_colors_images.rst | 305
-rw-r--r--  docs/source/auto_examples/plot_otda_semi_supervised.ipynb | 144
-rw-r--r--  docs/source/auto_examples/plot_otda_semi_supervised.py | 148
-rw-r--r--  docs/source/auto_examples/plot_otda_semi_supervised.rst | 240
-rw-r--r--  docs/source/auto_examples/searchindex | bin 0 -> 1892352 bytes
-rw-r--r--  docs/source/conf.py | 12
-rw-r--r--  docs/source/examples.rst | 39
-rw-r--r--  docs/source/readme.rst | 33
171 files changed, 6936 insertions(+), 3925 deletions(-)
diff --git a/docs/source/auto_examples/auto_examples_jupyter.zip b/docs/source/auto_examples/auto_examples_jupyter.zip
index 7c3de28..5a3f24c 100644
--- a/docs/source/auto_examples/auto_examples_jupyter.zip
+++ b/docs/source/auto_examples/auto_examples_jupyter.zip
Binary files differ
diff --git a/docs/source/auto_examples/auto_examples_python.zip b/docs/source/auto_examples/auto_examples_python.zip
index 97377e1..aa06bb6 100644
--- a/docs/source/auto_examples/auto_examples_python.zip
+++ b/docs/source/auto_examples/auto_examples_python.zip
Binary files differ
diff --git a/docs/source/auto_examples/demo_OT_1D_test.ipynb b/docs/source/auto_examples/demo_OT_1D_test.ipynb
deleted file mode 100644
index 87317ea..0000000
--- a/docs/source/auto_examples/demo_OT_1D_test.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\nDemo for 1D optimal transport\n\n@author: rflamary\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=n*.2,s=5) # m= mean, s= std\nb=gauss(n,m=n*.6,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')\n\n#%% Sinkhorn\n\nlambd=1e-4\nGss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)\nGss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')\n\n#%% Sinkhorn\n\nlambd=1e-11\nGss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/demo_OT_1D_test.py b/docs/source/auto_examples/demo_OT_1D_test.py
deleted file mode 100644
index 9edc377..0000000
--- a/docs/source/auto_examples/demo_OT_1D_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Demo for 1D optimal transport
-
-@author: rflamary
-"""
-
-import numpy as np
-import matplotlib.pylab as pl
-import ot
-from ot.datasets import get_1D_gauss as gauss
-
-
-#%% parameters
-
-n=100 # nb bins
-
-# bin positions
-x=np.arange(n,dtype=np.float64)
-
-# Gaussian distributions
-a=gauss(n,m=n*.2,s=5) # m= mean, s= std
-b=gauss(n,m=n*.6,s=10)
-
-# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-M/=M.max()
-
-#%% plot the distributions
-
-pl.figure(1)
-pl.plot(x,a,'b',label='Source distribution')
-pl.plot(x,b,'r',label='Target distribution')
-pl.legend()
-
-#%% plot distributions and loss matrix
-
-pl.figure(2)
-ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
-
-#%% EMD
-
-G0=ot.emd(a,b,M)
-
-pl.figure(3)
-ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
-
-#%% Sinkhorn
-
-lambd=1e-3
-Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
-
-pl.figure(4)
-ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
-
-#%% Sinkhorn
-
-lambd=1e-4
-Gss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)
-Gss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])
-
-pl.figure(5)
-ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
-
-#%% Sinkhorn
-
-lambd=1e-11
-Gss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)
-
-pl.figure(5)
-ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
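(The deleted script above calls `ot.sinkhorn` and its stabilized variants; plain Sinkhorn scaling under/overflows at very small regularization, which is exactly what the stabilized calls work around. As a reference for what `ot.sinkhorn(a, b, M, lambd)` computes, here is a self-contained NumPy sketch of the entropic-OT Sinkhorn iteration on the same kind of 1D histograms — an illustration only, not POT's actual implementation, and run at a moderate `reg` where the unstabilized iteration is safe.)

```python
import numpy as np

def sinkhorn_np(a, b, M, reg, n_iter=1000, tol=1e-9):
    """Entropic OT by Sinkhorn scaling: returns a coupling G = diag(u) K diag(v)
    whose row sums match a and column sums match b (up to tol)."""
    K = np.exp(-M / reg)                  # Gibbs kernel of the cost matrix
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)                 # scale columns toward marginal b
        u_new = a / (K @ v)               # scale rows toward marginal a
        if np.max(np.abs(u_new - u)) < tol:
            u = u_new
            break
        u = u_new
    return u[:, None] * K * v[None, :]

# Toy 1D histograms on a grid, in the spirit of the deleted demo
n = 100
x = np.arange(n, dtype=np.float64)
a = np.exp(-0.5 * ((x - 20) / 5) ** 2); a /= a.sum()    # source: Gaussian bump
b = np.exp(-0.5 * ((x - 60) / 10) ** 2); b /= b.sum()   # target: wider bump
M = (x[:, None] - x[None, :]) ** 2                      # squared-distance cost
M /= M.max()
G = sinkhorn_np(a, b, M, reg=0.1)
```

At `reg` values like the demo's `1e-3` or smaller, `exp(-M/reg)` underflows and the scaling vectors blow up; log-domain stabilization (as in `ot.bregman.sinkhorn_stabilized`) is the standard remedy.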
diff --git a/docs/source/auto_examples/demo_OT_1D_test.rst b/docs/source/auto_examples/demo_OT_1D_test.rst
deleted file mode 100644
index aebeb1d..0000000
--- a/docs/source/auto_examples/demo_OT_1D_test.rst
+++ /dev/null
@@ -1,99 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_demo_OT_1D_test.py:
-
-
-Demo for 1D optimal transport
-
-@author: rflamary
-
-
-
-.. code-block:: python
-
-
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from ot.datasets import get_1D_gauss as gauss
-
-
- #%% parameters
-
- n=100 # nb bins
-
- # bin positions
- x=np.arange(n,dtype=np.float64)
-
- # Gaussian distributions
- a=gauss(n,m=n*.2,s=5) # m= mean, s= std
- b=gauss(n,m=n*.6,s=10)
-
- # loss matrix
- M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
- M/=M.max()
-
- #%% plot the distributions
-
- pl.figure(1)
- pl.plot(x,a,'b',label='Source distribution')
- pl.plot(x,b,'r',label='Target distribution')
- pl.legend()
-
- #%% plot distributions and loss matrix
-
- pl.figure(2)
- ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
-
- #%% EMD
-
- G0=ot.emd(a,b,M)
-
- pl.figure(3)
- ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
-
- #%% Sinkhorn
-
- lambd=1e-3
- Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
-
- pl.figure(4)
- ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
-
- #%% Sinkhorn
-
- lambd=1e-4
- Gss,log=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True)
- Gss2,log2=ot.bregman.sinkhorn_stabilized(a,b,M,lambd,verbose=True,log=True,warmstart=log['warmstart'])
-
- pl.figure(5)
- ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
-
- #%% Sinkhorn
-
- lambd=1e-11
- Gss=ot.bregman.sinkhorn_epsilon_scaling(a,b,M,lambd,verbose=True)
-
- pl.figure(5)
- ot.plot.plot1D_mat(a,b,Gss,'OT matrix Sinkhorn stabilized')
-
-**Total running time of the script:** ( 0 minutes 0.000 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: demo_OT_1D_test.py <demo_OT_1D_test.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: demo_OT_1D_test.ipynb <demo_OT_1D_test.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb b/docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb
deleted file mode 100644
index 584a936..0000000
--- a/docs/source/auto_examples/demo_OT_2D_sampleslarge.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\nDemo for 2D Optimal transport between empirical distributions\n\n@author: rflamary\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=5000 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\n#pl.figure(1)\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('Source and traget distributions')\n#\n#pl.figure(2)\n#pl.imshow(M,interpolation='nearest')\n#pl.title('Cost matrix M')\n#\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\n#pl.figure(3)\n#pl.imshow(G0,interpolation='nearest')\n#pl.title('OT matrix G0')\n#\n#pl.figure(4)\n#ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-3\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\n#pl.figure(5)\n#pl.imshow(Gs,interpolation='nearest')\n#pl.title('OT matrix sinkhorn')\n#\n#pl.figure(6)\n#ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\n#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n#pl.legend(loc=0)\n#pl.title('OT matrix Sinkhorn with samples')\n#"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/demo_OT_2D_sampleslarge.py b/docs/source/auto_examples/demo_OT_2D_sampleslarge.py
deleted file mode 100644
index ee3e8f7..0000000
--- a/docs/source/auto_examples/demo_OT_2D_sampleslarge.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Demo for 2D Optimal transport between empirical distributions
-
-@author: rflamary
-"""
-
-import numpy as np
-import matplotlib.pylab as pl
-import ot
-
-#%% parameters and data generation
-
-n=5000 # nb samples
-
-mu_s=np.array([0,0])
-cov_s=np.array([[1,0],[0,1]])
-
-mu_t=np.array([4,4])
-cov_t=np.array([[1,-.8],[-.8,1]])
-
-xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
-xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)
-
-a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
-
-# loss matrix
-M=ot.dist(xs,xt)
-M/=M.max()
-
-#%% plot samples
-
-#pl.figure(1)
-#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
-#pl.legend(loc=0)
-#pl.title('Source and traget distributions')
-#
-#pl.figure(2)
-#pl.imshow(M,interpolation='nearest')
-#pl.title('Cost matrix M')
-#
-
-#%% EMD
-
-G0=ot.emd(a,b,M)
-
-#pl.figure(3)
-#pl.imshow(G0,interpolation='nearest')
-#pl.title('OT matrix G0')
-#
-#pl.figure(4)
-#ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
-#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
-#pl.legend(loc=0)
-#pl.title('OT matrix with samples')
-
-
-#%% sinkhorn
-
-# reg term
-lambd=5e-3
-
-Gs=ot.sinkhorn(a,b,M,lambd)
-
-#pl.figure(5)
-#pl.imshow(Gs,interpolation='nearest')
-#pl.title('OT matrix sinkhorn')
-#
-#pl.figure(6)
-#ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
-#pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-#pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
-#pl.legend(loc=0)
-#pl.title('OT matrix Sinkhorn with samples')
-#
-
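(The deleted 2D demo above builds its loss matrix with `M = ot.dist(xs, xt)`, whose default metric is squared Euclidean distance, then normalizes by the maximum. A self-contained NumPy sketch of that cost-matrix construction — for illustration, not POT's actual code — using the vectorized identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y:)

```python
import numpy as np

rng = np.random.default_rng(0)

def pairwise_sqeuclidean(xs, xt):
    """Cost matrix M[i, j] = ||xs[i] - xt[j]||^2 without explicit Python loops
    (the default metric of ot.dist)."""
    sq_s = (xs ** 2).sum(axis=1)[:, None]   # ||xs_i||^2 as a column vector
    sq_t = (xt ** 2).sum(axis=1)[None, :]   # ||xt_j||^2 as a row vector
    M = sq_s + sq_t - 2.0 * xs @ xt.T
    return np.maximum(M, 0.0)               # clip tiny negative round-off

# Small stand-in for the demo's two Gaussian clouds
xs = rng.multivariate_normal([0, 0], [[1, 0], [0, 1]], size=50)
xt = rng.multivariate_normal([4, 4], [[1, -0.8], [-0.8, 1]], size=50)
M = pairwise_sqeuclidean(xs, xt)
M /= M.max()   # same normalization as the demo
```

The normalization keeps the cost scale independent of how far apart the clouds are, so the same regularization value behaves comparably across datasets.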
diff --git a/docs/source/auto_examples/demo_OT_2D_sampleslarge.rst b/docs/source/auto_examples/demo_OT_2D_sampleslarge.rst
deleted file mode 100644
index f5dbb0d..0000000
--- a/docs/source/auto_examples/demo_OT_2D_sampleslarge.rst
+++ /dev/null
@@ -1,106 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_demo_OT_2D_sampleslarge.py:
-
-
-Demo for 2D Optimal transport between empirical distributions
-
-@author: rflamary
-
-
-
-.. code-block:: python
-
-
- import numpy as np
- import matplotlib.pylab as pl
- import ot
-
- #%% parameters and data generation
-
- n=5000 # nb samples
-
- mu_s=np.array([0,0])
- cov_s=np.array([[1,0],[0,1]])
-
- mu_t=np.array([4,4])
- cov_t=np.array([[1,-.8],[-.8,1]])
-
- xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
- xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)
-
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
-
- # loss matrix
- M=ot.dist(xs,xt)
- M/=M.max()
-
- #%% plot samples
-
- #pl.figure(1)
- #pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- #pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- #pl.legend(loc=0)
- #pl.title('Source and traget distributions')
- #
- #pl.figure(2)
- #pl.imshow(M,interpolation='nearest')
- #pl.title('Cost matrix M')
- #
-
- #%% EMD
-
- G0=ot.emd(a,b,M)
-
- #pl.figure(3)
- #pl.imshow(G0,interpolation='nearest')
- #pl.title('OT matrix G0')
- #
- #pl.figure(4)
- #ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
- #pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- #pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- #pl.legend(loc=0)
- #pl.title('OT matrix with samples')
-
-
- #%% sinkhorn
-
- # reg term
- lambd=5e-3
-
- Gs=ot.sinkhorn(a,b,M,lambd)
-
- #pl.figure(5)
- #pl.imshow(Gs,interpolation='nearest')
- #pl.title('OT matrix sinkhorn')
- #
- #pl.figure(6)
- #ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
- #pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- #pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- #pl.legend(loc=0)
- #pl.title('OT matrix Sinkhorn with samples')
- #
-
-
-**Total running time of the script:** ( 0 minutes 0.000 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: demo_OT_2D_sampleslarge.py <demo_OT_2D_sampleslarge.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: demo_OT_2D_sampleslarge.ipynb <demo_OT_2D_sampleslarge.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png
deleted file mode 100644
index 7de2b45..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_001.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png
deleted file mode 100644
index dc34efd..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_002.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png
deleted file mode 100644
index fbd72d5..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_003.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png
deleted file mode 100644
index 227812d..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_2D_004.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png
deleted file mode 100644
index 2bf4015..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_001.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png
deleted file mode 100644
index c1fbf57..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_classes_004.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png
deleted file mode 100644
index 36bc769..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png
deleted file mode 100644
index 307e384..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png
deleted file mode 100644
index 8c700ee..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png
deleted file mode 100644
index 792b404..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png
deleted file mode 100644
index 36bc769..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png
deleted file mode 100644
index 008bf15..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
index da42bc1..e11f5b9 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
index 1f98598..fcab0bd 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png
deleted file mode 100644
index 9e893d6..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_003.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png
deleted file mode 100644
index 3bc248b..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_004.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png
new file mode 100644
index 0000000..a75e649
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png
new file mode 100644
index 0000000..96b42cd
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_1D_007.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
index e023ab4..2ea9ead 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
index dda21d4..cb6f1a1 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png
deleted file mode 100644
index f0967fb..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png
deleted file mode 100644
index 809c8fc..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
index 887bdde..895ff65 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
index 783c594..a056401 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
new file mode 100644
index 0000000..285d474
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
new file mode 100644
index 0000000..30ef388
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
index b159a6a..6a21f35 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
index 9f8e882..79e4710 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
index 33058fc..4860d96 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
index 9848bcb..6a21f35 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
index 6616d3c..4860d96 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
index 8575d93..4860d96 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png
new file mode 100644
index 0000000..22dba2b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png
new file mode 100644
index 0000000..5dbf96b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png
new file mode 100644
index 0000000..e1e9ba8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_009.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png
new file mode 100644
index 0000000..e1e9ba8
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png b/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
index 250155d..3524e19 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_WDA_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png b/docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png
new file mode 100644
index 0000000..819b974
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_WDA_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
index be71674..3454396 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
index f62240b..3b23af5 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
index 11f08b2..3b23af5 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
index b4e8f71..2e29ff9 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
new file mode 100644
index 0000000..eac9230
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
new file mode 100644
index 0000000..2e29ff9
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
index 4917903..9cf84c6 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png
deleted file mode 100644
index 7c06255..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_002.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png
new file mode 100644
index 0000000..2da6ee7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png
new file mode 100644
index 0000000..d74c34a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_compute_emd_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png
new file mode 100644
index 0000000..b4571fa
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png
new file mode 100644
index 0000000..58c02d7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_002.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png
new file mode 100644
index 0000000..73a322d
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png b/docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png
new file mode 100644
index 0000000..715a116
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
index 7ffcc14..a75e649 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
index 2a72060..7afdb53 100644
--- a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png
deleted file mode 100644
index e70a6de..0000000
--- a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_005.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png
new file mode 100644
index 0000000..60078c1
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png
new file mode 100644
index 0000000..8a4882a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_optim_OTreg_008.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png
new file mode 100644
index 0000000..48fad93
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png
new file mode 100644
index 0000000..c92d5c1
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_classes_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png
new file mode 100644
index 0000000..95f882a
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png
new file mode 100644
index 0000000..aa1a5d3
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png
new file mode 100644
index 0000000..d219bb3
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_color_images_005.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png
new file mode 100644
index 0000000..ef8cfd1
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png
new file mode 100644
index 0000000..1ba5b1b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png
new file mode 100644
index 0000000..d67fea1
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_d2_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png
new file mode 100644
index 0000000..8da464b
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png
new file mode 100644
index 0000000..fa93ee5
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png
new file mode 100644
index 0000000..33134fc
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png
new file mode 100644
index 0000000..42197e3
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png
new file mode 100644
index 0000000..d9101da
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png
new file mode 100644
index 0000000..324aee3
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png
new file mode 100644
index 0000000..8ad6ca2
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png
Binary files differ
diff --git a/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png
new file mode 100644
index 0000000..b4eacb7
--- /dev/null
+++ b/docs/source/auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png
deleted file mode 100644
index d15269d..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png
deleted file mode 100644
index 5863d02..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png
deleted file mode 100644
index 5bb43c4..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png
deleted file mode 100644
index 5bb43c4..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png
deleted file mode 100644
index c3d9a65..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
index 15c9825..63ff40c 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
index bac78f0..22281f4 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_2D_samples_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
index c67e8aa..95588f5 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png
deleted file mode 100644
index 3015582..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_OT_conv_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
index 84759e8..2316fcc 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
index 86ff19f..5c17671 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
index 67d2ca1..68cbdf7 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_compute_emd_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png
new file mode 100644
index 0000000..85a94ff
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
new file mode 100644
index 0000000..26b0b2f
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
index 2a72060..cbc8e0f 100644
--- a/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png
new file mode 100644
index 0000000..561c5bb
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png
new file mode 100644
index 0000000..a919055
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png
new file mode 100644
index 0000000..bd32092
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png
new file mode 100644
index 0000000..f7fd217
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png
new file mode 100644
index 0000000..37d99bd
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png
new file mode 100644
index 0000000..e1b5863
--- /dev/null
+++ b/docs/source/auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png
Binary files differ
diff --git a/docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png b/docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png
deleted file mode 100644
index cbc8e0f..0000000
--- a/docs/source/auto_examples/images/thumb/sphx_glr_test_OT_2D_samples_stabilized_thumb.png
+++ /dev/null
Binary files differ
diff --git a/docs/source/auto_examples/index.rst b/docs/source/auto_examples/index.rst
index 1695300..eb54ca8 100644
--- a/docs/source/auto_examples/index.rst
+++ b/docs/source/auto_examples/index.rst
@@ -1,9 +1,11 @@
POT Examples
============
+This is a gallery of all the POT example files.
+
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary ">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of EMD and Sinkhorn transport plans and their visualiz...">
.. only:: html
@@ -23,13 +25,13 @@ POT Examples
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary ">
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustrates the use of the generic solver for regularized OT with user-designed regularization ...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
- :ref:`sphx_glr_auto_examples_plot_WDA.py`
+ :ref:`sphx_glr_auto_examples_plot_optim_OTreg.py`
.. raw:: html
@@ -39,17 +41,17 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_WDA
+ /auto_examples/plot_optim_OTreg
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip=" ">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wasserstein distance computation in POT....">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_optim_OTreg_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_gromov_thumb.png
- :ref:`sphx_glr_auto_examples_plot_optim_OTreg.py`
+ :ref:`sphx_glr_auto_examples_plot_gromov.py`
.. raw:: html
@@ -59,11 +61,11 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_optim_OTreg
+ /auto_examples/plot_gromov
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary ">
+ <div class="sphx-glr-thumbcontainer" tooltip="Illustration of 2D optimal transport between distributions that are weighted sums of Diracs. The...">
.. only:: html
@@ -83,7 +85,7 @@ POT Examples
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="@author: rflamary ">
+ <div class="sphx-glr-thumbcontainer" tooltip="Shows how to compute multiple EMD and Sinkhorn with two different ground metrics and plot their ...">
.. only:: html
@@ -103,13 +105,13 @@ POT Examples
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optima...">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the use of WDA as proposed in [11].">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_color_images_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_WDA_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OTDA_color_images.py`
+ :ref:`sphx_glr_auto_examples_plot_WDA.py`
.. raw:: html
@@ -119,17 +121,17 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_OTDA_color_images
+ /auto_examples/plot_WDA
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example presents a way of transferring colors between two images with Optimal Transport as ...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_classes_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_color_images_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OTDA_classes.py`
+ :ref:`sphx_glr_auto_examples_plot_otda_color_images.py`
.. raw:: html
@@ -139,17 +141,17 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_OTDA_classes
+ /auto_examples/plot_otda_color_images
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example illustrates the computation of a regularized Wasserstein barycenter as proposed in [...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_2D_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OTDA_2D.py`
+ :ref:`sphx_glr_auto_examples_plot_barycenter_1D.py`
.. raw:: html
@@ -159,17 +161,17 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_OTDA_2D
+ /auto_examples/plot_barycenter_1D
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="Stole the figure idea from Fig. 1 and 2 in https://arxiv.org/pdf/1706.07650.pdf">
+ <div class="sphx-glr-thumbcontainer" tooltip="OT for domain adaptation with image color adaptation [6] with mapping estimation [8].">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_mapping_colors_images_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OT_L1_vs_L2.py`
+ :ref:`sphx_glr_auto_examples_plot_otda_mapping_colors_images.py`
.. raw:: html
@@ -179,17 +181,17 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_OT_L1_vs_L2
+ /auto_examples/plot_otda_mapping_colors_images
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip=" @author: rflamary ">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example presents how to use MappingTransport to estimate at the same time both the couplin...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_barycenter_1D_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_mapping_thumb.png
- :ref:`sphx_glr_auto_examples_plot_barycenter_1D.py`
+ :ref:`sphx_glr_auto_examples_plot_otda_mapping.py`
.. raw:: html
@@ -199,17 +201,17 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_barycenter_1D
+ /auto_examples/plot_otda_mapping
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete op...">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a semi supervised domain adaptation in a 2D setting. It explicits the p...">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_color_images_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_semi_supervised_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OTDA_mapping_color_images.py`
+ :ref:`sphx_glr_auto_examples_plot_otda_semi_supervised.py`
.. raw:: html
@@ -219,17 +221,77 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_OTDA_mapping_color_images
+ /auto_examples/plot_otda_semi_supervised
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a domain adaptation in a 2D setting and the 4 OTDA approaches currently...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_classes_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_classes.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_classes
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="This example introduces a domain adaptation in a 2D setting. It explicits the problem of domain...">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_otda_d2_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_otda_d2.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_otda_d2
+
+.. raw:: html
+
+ <div class="sphx-glr-thumbcontainer" tooltip="2D OT on empirical distributions with different ground metrics.">
+
+.. only:: html
+
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OT_L1_vs_L2_thumb.png
+
+ :ref:`sphx_glr_auto_examples_plot_OT_L1_vs_L2.py`
+
+.. raw:: html
+
+ </div>
+
+
+.. toctree::
+ :hidden:
+
+ /auto_examples/plot_OT_L1_vs_L2
.. raw:: html
- <div class="sphx-glr-thumbcontainer" tooltip="[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal ...">
+ <div class="sphx-glr-thumbcontainer" tooltip="This example is designed to show how to use the Gromov-Wasserstein distance computation in POT....">
.. only:: html
- .. figure:: /auto_examples/images/thumb/sphx_glr_plot_OTDA_mapping_thumb.png
+ .. figure:: /auto_examples/images/thumb/sphx_glr_plot_gromov_barycenter_thumb.png
- :ref:`sphx_glr_auto_examples_plot_OTDA_mapping.py`
+ :ref:`sphx_glr_auto_examples_plot_gromov_barycenter.py`
.. raw:: html
@@ -239,7 +301,7 @@ POT Examples
.. toctree::
:hidden:
- /auto_examples/plot_OTDA_mapping
+ /auto_examples/plot_gromov_barycenter
.. raw:: html
<div style='clear:both'></div>
diff --git a/docs/source/auto_examples/plot_OTDA_2D.ipynb b/docs/source/auto_examples/plot_OTDA_2D.ipynb
deleted file mode 100644
index 2ffb256..0000000
--- a/docs/source/auto_examples/plot_OTDA_2D.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n# OT for empirical distributions\n\n\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=150 # nb bins\n\nxs,ys=ot.datasets.get_data_classif('3gauss',n)\nxt,yt=ot.datasets.get_data_classif('3gauss2',n)\n\na,b = ot.unif(n),ot.unif(n)\n# loss matrix\nM=ot.dist(xs,xt)\n#M/=M.max()\n\n#%% plot samples\n\npl.figure(1)\n\npl.subplot(2,2,1)\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Source distributions')\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\npl.legend(loc=0)\npl.title('target distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% OT estimation\n\n# EMD\nG0=ot.emd(a,b,M)\n\n# sinkhorn\nlambd=1e-1\nGs=ot.sinkhorn(a,b,M,lambd)\n\n\n# Group lasso regularization\nreg=1e-1\neta=1e0\nGg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)\n\n\n#%% visu matrices\n\npl.figure(3)\n\npl.subplot(2,3,1)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix ')\n\npl.subplot(2,3,2)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix Sinkhorn')\n\npl.subplot(2,3,3)\npl.imshow(Gg,interpolation='nearest')\npl.title('OT matrix Group lasso')\n\npl.subplot(2,3,4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\n\npl.subplot(2,3,5)\not.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\npl.subplot(2,3,6)\not.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\n#%% sample interpolation\n\nxst0=n*G0.dot(xt)\nxsts=n*Gs.dot(xt)\nxstg=n*Gg.dot(xt)\n\npl.figure(4)\npl.subplot(2,3,1)\n\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples')\npl.legend(loc=0)\n\npl.subplot(2,3,2)\n\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Sinkhorn')\n\npl.subplot(2,3,3)\n\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)\npl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Grouplasso')"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
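Editor's note on the file deleted above: the example builds its loss matrix with `M=ot.dist(xs,xt)`, which in POT defaults to squared Euclidean distances between sample pairs. As a standalone illustration with no POT dependency (the helper name below is invented for this sketch), that cost matrix can be computed in pure NumPy:

```python
import numpy as np

def dist_sqeuclidean(xs, xt):
    """Pairwise squared Euclidean cost matrix, mirroring what the
    deleted example obtains from ot.dist(xs, xt) with default metric."""
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>, computed without loops
    xs2 = np.sum(xs ** 2, axis=1)[:, None]   # (n_s, 1)
    xt2 = np.sum(xt ** 2, axis=1)[None, :]   # (1, n_t)
    M = xs2 + xt2 - 2.0 * xs @ xt.T
    return np.maximum(M, 0.0)  # clip tiny negatives from round-off
```

`M[i, j]` is then the cost of moving mass from source sample `i` to target sample `j`, exactly the quantity visualized as "Cost matrix M" in the example.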
diff --git a/docs/source/auto_examples/plot_OTDA_2D.py b/docs/source/auto_examples/plot_OTDA_2D.py
deleted file mode 100644
index a1fb804..0000000
--- a/docs/source/auto_examples/plot_OTDA_2D.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-==============================
-OT for empirical distributions
-==============================
-
-"""
-
-import numpy as np
-import matplotlib.pylab as pl
-import ot
-
-
-
-#%% parameters
-
-n=150 # nb bins
-
-xs,ys=ot.datasets.get_data_classif('3gauss',n)
-xt,yt=ot.datasets.get_data_classif('3gauss2',n)
-
-a,b = ot.unif(n),ot.unif(n)
-# loss matrix
-M=ot.dist(xs,xt)
-#M/=M.max()
-
-#%% plot samples
-
-pl.figure(1)
-
-pl.subplot(2,2,1)
-pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
-pl.legend(loc=0)
-pl.title('Source distributions')
-
-pl.subplot(2,2,2)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-pl.legend(loc=0)
-pl.title('target distributions')
-
-pl.figure(2)
-pl.imshow(M,interpolation='nearest')
-pl.title('Cost matrix M')
-
-
-#%% OT estimation
-
-# EMD
-G0=ot.emd(a,b,M)
-
-# sinkhorn
-lambd=1e-1
-Gs=ot.sinkhorn(a,b,M,lambd)
-
-
-# Group lasso regularization
-reg=1e-1
-eta=1e0
-Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)
-
-
-#%% visu matrices
-
-pl.figure(3)
-
-pl.subplot(2,3,1)
-pl.imshow(G0,interpolation='nearest')
-pl.title('OT matrix ')
-
-pl.subplot(2,3,2)
-pl.imshow(Gs,interpolation='nearest')
-pl.title('OT matrix Sinkhorn')
-
-pl.subplot(2,3,3)
-pl.imshow(Gg,interpolation='nearest')
-pl.title('OT matrix Group lasso')
-
-pl.subplot(2,3,4)
-ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
-pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
-
-pl.subplot(2,3,5)
-ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])
-pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
-pl.subplot(2,3,6)
-ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])
-pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
-#%% sample interpolation
-
-xst0=n*G0.dot(xt)
-xsts=n*Gs.dot(xt)
-xstg=n*Gg.dot(xt)
-
-pl.figure(4)
-pl.subplot(2,3,1)
-
-
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
-pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples')
-pl.legend(loc=0)
-
-pl.subplot(2,3,2)
-
-
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
-pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples Sinkhorn')
-
-pl.subplot(2,3,3)
-
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
-pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples Grouplasso') \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OTDA_2D.rst b/docs/source/auto_examples/plot_OTDA_2D.rst
deleted file mode 100644
index b535bb0..0000000
--- a/docs/source/auto_examples/plot_OTDA_2D.rst
+++ /dev/null
@@ -1,175 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OTDA_2D.py:
-
-
-==============================
-OT for empirical distributions
-==============================
-
-
-
-
-
-.. rst-class:: sphx-glr-horizontal
-
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_001.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_002.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_003.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_2D_004.png
- :scale: 47
-
-
-
-
-
-.. code-block:: python
-
-
- import numpy as np
- import matplotlib.pylab as pl
- import ot
-
-
-
- #%% parameters
-
- n=150 # nb bins
-
- xs,ys=ot.datasets.get_data_classif('3gauss',n)
- xt,yt=ot.datasets.get_data_classif('3gauss2',n)
-
- a,b = ot.unif(n),ot.unif(n)
- # loss matrix
- M=ot.dist(xs,xt)
- #M/=M.max()
-
- #%% plot samples
-
- pl.figure(1)
-
- pl.subplot(2,2,1)
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.legend(loc=0)
- pl.title('Source distributions')
-
- pl.subplot(2,2,2)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
- pl.legend(loc=0)
- pl.title('target distributions')
-
- pl.figure(2)
- pl.imshow(M,interpolation='nearest')
- pl.title('Cost matrix M')
-
-
- #%% OT estimation
-
- # EMD
- G0=ot.emd(a,b,M)
-
- # sinkhorn
- lambd=1e-1
- Gs=ot.sinkhorn(a,b,M,lambd)
-
-
- # Group lasso regularization
- reg=1e-1
- eta=1e0
- Gg=ot.da.sinkhorn_lpl1_mm(a,ys.astype(np.int),b,M,reg,eta)
-
-
- #%% visu matrices
-
- pl.figure(3)
-
- pl.subplot(2,3,1)
- pl.imshow(G0,interpolation='nearest')
- pl.title('OT matrix ')
-
- pl.subplot(2,3,2)
- pl.imshow(Gs,interpolation='nearest')
- pl.title('OT matrix Sinkhorn')
-
- pl.subplot(2,3,3)
- pl.imshow(Gg,interpolation='nearest')
- pl.title('OT matrix Group lasso')
-
- pl.subplot(2,3,4)
- ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
-
- pl.subplot(2,3,5)
- ot.plot.plot2D_samples_mat(xs,xt,Gs,c=[.5,.5,1])
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
- pl.subplot(2,3,6)
- ot.plot.plot2D_samples_mat(xs,xt,Gg,c=[.5,.5,1])
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
- #%% sample interpolation
-
- xst0=n*G0.dot(xt)
- xsts=n*Gs.dot(xt)
- xstg=n*Gg.dot(xt)
-
- pl.figure(4)
- pl.subplot(2,3,1)
-
-
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
- pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples')
- pl.legend(loc=0)
-
- pl.subplot(2,3,2)
-
-
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
- pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples Sinkhorn')
-
- pl.subplot(2,3,3)
-
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.5)
- pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples Grouplasso')
-**Total running time of the script:** ( 0 minutes 17.372 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OTDA_2D.py <plot_OTDA_2D.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OTDA_2D.ipynb <plot_OTDA_2D.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
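Editor's note: the script removed above solves the entropic problem with `ot.sinkhorn(a, b, M, lambd)`. A minimal Sinkhorn-Knopp sketch of that computation is shown below, NumPy only and for illustration: it omits the stopping criteria and numerical safeguards (e.g. log-domain stabilization) of the library solver.

```python
import numpy as np

def sinkhorn(a, b, M, reg, n_iter=1000):
    """Entropic OT by alternate scaling: returns a coupling G whose
    marginals approximate a (rows) and b (columns)."""
    K = np.exp(-M / reg)                 # Gibbs kernel of the cost
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)                # rescale columns toward b
        u = a / (K @ v)                  # rescale rows toward a
    return u[:, None] * K * v[None, :]   # G = diag(u) K diag(v)
```

Because the loop ends with the row update, the row marginals of `G` match `a` exactly; the column marginals converge to `b` as iterations increase.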
diff --git a/docs/source/auto_examples/plot_OTDA_classes.ipynb b/docs/source/auto_examples/plot_OTDA_classes.ipynb
deleted file mode 100644
index d9fcb87..0000000
--- a/docs/source/auto_examples/plot_OTDA_classes.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n# OT for domain adaptation\n\n\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import matplotlib.pylab as pl\nimport ot\n\n\n\n\n#%% parameters\n\nn=150 # nb samples in source and target datasets\n\nxs,ys=ot.datasets.get_data_classif('3gauss',n)\nxt,yt=ot.datasets.get_data_classif('3gauss2',n)\n\n\n\n\n#%% plot samples\n\npl.figure(1)\n\npl.subplot(2,2,1)\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Source distributions')\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\npl.legend(loc=0)\npl.title('target distributions')\n\n\n#%% OT estimation\n\n# LP problem\nda_emd=ot.da.OTDA() # init class\nda_emd.fit(xs,xt) # fit distributions\nxst0=da_emd.interp() # interpolation of source samples\n\n\n# sinkhorn regularization\nlambd=1e-1\nda_entrop=ot.da.OTDA_sinkhorn()\nda_entrop.fit(xs,xt,reg=lambd)\nxsts=da_entrop.interp()\n\n# non-convex Group lasso regularization\nreg=1e-1\neta=1e0\nda_lpl1=ot.da.OTDA_lpl1()\nda_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)\nxstg=da_lpl1.interp()\n\n\n# True Group lasso regularization\nreg=1e-1\neta=2e0\nda_l1l2=ot.da.OTDA_l1l2()\nda_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)\nxstgl=da_l1l2.interp()\n\n\n#%% plot interpolated source samples\npl.figure(4,(15,8))\n\nparam_img={'interpolation':'nearest','cmap':'jet'}\n\npl.subplot(2,4,1)\npl.imshow(da_emd.G,**param_img)\npl.title('OT matrix')\n\n\npl.subplot(2,4,2)\npl.imshow(da_entrop.G,**param_img)\npl.title('OT matrix sinkhorn')\n\npl.subplot(2,4,3)\npl.imshow(da_lpl1.G,**param_img)\npl.title('OT matrix non-convex Group Lasso')\n\npl.subplot(2,4,4)\npl.imshow(da_l1l2.G,**param_img)\npl.title('OT matrix Group Lasso')\n\n\npl.subplot(2,4,5)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples')\npl.legend(loc=0)\n\npl.subplot(2,4,6)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Sinkhorn')\n\npl.subplot(2,4,7)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples non-convex Group Lasso')\n\npl.subplot(2,4,8)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)\npl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)\npl.title('Interp samples Group Lasso')"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OTDA_classes.py b/docs/source/auto_examples/plot_OTDA_classes.py
deleted file mode 100644
index 089b45b..0000000
--- a/docs/source/auto_examples/plot_OTDA_classes.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-========================
-OT for domain adaptation
-========================
-
-"""
-
-import matplotlib.pylab as pl
-import ot
-
-
-
-
-#%% parameters
-
-n=150 # nb samples in source and target datasets
-
-xs,ys=ot.datasets.get_data_classif('3gauss',n)
-xt,yt=ot.datasets.get_data_classif('3gauss2',n)
-
-
-
-
-#%% plot samples
-
-pl.figure(1)
-
-pl.subplot(2,2,1)
-pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
-pl.legend(loc=0)
-pl.title('Source distributions')
-
-pl.subplot(2,2,2)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-pl.legend(loc=0)
-pl.title('target distributions')
-
-
-#%% OT estimation
-
-# LP problem
-da_emd=ot.da.OTDA() # init class
-da_emd.fit(xs,xt) # fit distributions
-xst0=da_emd.interp() # interpolation of source samples
-
-
-# sinkhorn regularization
-lambd=1e-1
-da_entrop=ot.da.OTDA_sinkhorn()
-da_entrop.fit(xs,xt,reg=lambd)
-xsts=da_entrop.interp()
-
-# non-convex Group lasso regularization
-reg=1e-1
-eta=1e0
-da_lpl1=ot.da.OTDA_lpl1()
-da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
-xstg=da_lpl1.interp()
-
-
-# True Group lasso regularization
-reg=1e-1
-eta=2e0
-da_l1l2=ot.da.OTDA_l1l2()
-da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
-xstgl=da_l1l2.interp()
-
-
-#%% plot interpolated source samples
-pl.figure(4,(15,8))
-
-param_img={'interpolation':'nearest','cmap':'jet'}
-
-pl.subplot(2,4,1)
-pl.imshow(da_emd.G,**param_img)
-pl.title('OT matrix')
-
-
-pl.subplot(2,4,2)
-pl.imshow(da_entrop.G,**param_img)
-pl.title('OT matrix sinkhorn')
-
-pl.subplot(2,4,3)
-pl.imshow(da_lpl1.G,**param_img)
-pl.title('OT matrix non-convex Group Lasso')
-
-pl.subplot(2,4,4)
-pl.imshow(da_l1l2.G,**param_img)
-pl.title('OT matrix Group Lasso')
-
-
-pl.subplot(2,4,5)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
-pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples')
-pl.legend(loc=0)
-
-pl.subplot(2,4,6)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
-pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples Sinkhorn')
-
-pl.subplot(2,4,7)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
-pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples non-convex Group Lasso')
-
-pl.subplot(2,4,8)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
-pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
-pl.title('Interp samples Group Lasso') \ No newline at end of file
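Editor's note: the deleted scripts map source samples through the coupling with `xst0=n*G0.dot(xt)` (equivalently `da_emd.interp()` in the old class API). For uniform source weights `a = 1/n` this is the barycentric mapping; a small NumPy sketch that also covers non-uniform weights (the function name is invented here for illustration) is:

```python
import numpy as np

def interp_samples(G, xt):
    """Barycentric mapping under a coupling G: each source sample is
    sent to the G-weighted mean of the target samples. Reduces to
    n * G.dot(xt), as in the deleted examples, when the source
    weights are uniform (row sums of G all equal to 1/n)."""
    w = G.sum(axis=1, keepdims=True)         # row marginals of the plan
    return (G @ xt) / np.maximum(w, 1e-16)   # normalize each row's barycenter
```

This is the "Interp samples" / "Transp samples" quantity plotted in figure 4 of both removed examples.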
diff --git a/docs/source/auto_examples/plot_OTDA_classes.rst b/docs/source/auto_examples/plot_OTDA_classes.rst
deleted file mode 100644
index 097e9fc..0000000
--- a/docs/source/auto_examples/plot_OTDA_classes.rst
+++ /dev/null
@@ -1,190 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OTDA_classes.py:
-
-
-========================
-OT for domain adaptation
-========================
-
-
-
-
-
-.. rst-class:: sphx-glr-horizontal
-
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_001.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_classes_004.png
- :scale: 47
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out::
-
- It. |Loss |Delta loss
- --------------------------------
- 0|9.171271e+00|0.000000e+00
- 1|2.133783e+00|-3.298127e+00
- 2|1.895941e+00|-1.254484e-01
- 3|1.844628e+00|-2.781709e-02
- 4|1.824983e+00|-1.076467e-02
- 5|1.815453e+00|-5.249337e-03
- 6|1.808104e+00|-4.064733e-03
- 7|1.803558e+00|-2.520475e-03
- 8|1.801061e+00|-1.386155e-03
- 9|1.799391e+00|-9.279565e-04
- 10|1.797176e+00|-1.232778e-03
- 11|1.795465e+00|-9.529479e-04
- 12|1.795316e+00|-8.322362e-05
- 13|1.794523e+00|-4.418932e-04
- 14|1.794444e+00|-4.390599e-05
- 15|1.794395e+00|-2.710318e-05
- 16|1.793713e+00|-3.804028e-04
- 17|1.793110e+00|-3.359479e-04
- 18|1.792829e+00|-1.569563e-04
- 19|1.792621e+00|-1.159469e-04
- It. |Loss |Delta loss
- --------------------------------
- 20|1.791334e+00|-7.187689e-04
-
-
-
-
-|
-
-
-.. code-block:: python
-
-
- import matplotlib.pylab as pl
- import ot
-
-
-
-
- #%% parameters
-
- n=150 # nb samples in source and target datasets
-
- xs,ys=ot.datasets.get_data_classif('3gauss',n)
- xt,yt=ot.datasets.get_data_classif('3gauss2',n)
-
-
-
-
- #%% plot samples
-
- pl.figure(1)
-
- pl.subplot(2,2,1)
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.legend(loc=0)
- pl.title('Source distributions')
-
- pl.subplot(2,2,2)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
- pl.legend(loc=0)
- pl.title('target distributions')
-
-
- #%% OT estimation
-
- # LP problem
- da_emd=ot.da.OTDA() # init class
- da_emd.fit(xs,xt) # fit distributions
- xst0=da_emd.interp() # interpolation of source samples
-
-
- # sinkhorn regularization
- lambd=1e-1
- da_entrop=ot.da.OTDA_sinkhorn()
- da_entrop.fit(xs,xt,reg=lambd)
- xsts=da_entrop.interp()
-
- # non-convex Group lasso regularization
- reg=1e-1
- eta=1e0
- da_lpl1=ot.da.OTDA_lpl1()
- da_lpl1.fit(xs,ys,xt,reg=reg,eta=eta)
- xstg=da_lpl1.interp()
-
-
- # True Group lasso regularization
- reg=1e-1
- eta=2e0
- da_l1l2=ot.da.OTDA_l1l2()
- da_l1l2.fit(xs,ys,xt,reg=reg,eta=eta,numItermax=20,verbose=True)
- xstgl=da_l1l2.interp()
-
-
- #%% plot interpolated source samples
- pl.figure(4,(15,8))
-
- param_img={'interpolation':'nearest','cmap':'jet'}
-
- pl.subplot(2,4,1)
- pl.imshow(da_emd.G,**param_img)
- pl.title('OT matrix')
-
-
- pl.subplot(2,4,2)
- pl.imshow(da_entrop.G,**param_img)
- pl.title('OT matrix sinkhorn')
-
- pl.subplot(2,4,3)
- pl.imshow(da_lpl1.G,**param_img)
- pl.title('OT matrix non-convex Group Lasso')
-
- pl.subplot(2,4,4)
- pl.imshow(da_l1l2.G,**param_img)
- pl.title('OT matrix Group Lasso')
-
-
- pl.subplot(2,4,5)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples')
- pl.legend(loc=0)
-
- pl.subplot(2,4,6)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xsts[:,0],xsts[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples Sinkhorn')
-
- pl.subplot(2,4,7)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xstg[:,0],xstg[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples non-convex Group Lasso')
-
- pl.subplot(2,4,8)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=0.3)
- pl.scatter(xstgl[:,0],xstgl[:,1],c=ys,marker='+',label='Transp samples',s=30)
- pl.title('Interp samples Group Lasso')
-**Total running time of the script:** ( 0 minutes 2.225 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OTDA_classes.py <plot_OTDA_classes.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OTDA_classes.ipynb <plot_OTDA_classes.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OTDA_color_images.ipynb b/docs/source/auto_examples/plot_OTDA_color_images.ipynb
deleted file mode 100644
index d174828..0000000
--- a/docs/source/auto_examples/plot_OTDA_color_images.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n========================================================\nOT for domain adaptation with image color adaptation [6]\n========================================================\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport scipy.ndimage as spi\nimport matplotlib.pylab as pl\nimport ot\n\n\n#%% Loading images\n\nI1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256\nI2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256\n\n#%% Plot images\n\npl.figure(1)\n\npl.subplot(1,2,1)\npl.imshow(I1)\npl.title('Image 1')\n\npl.subplot(1,2,2)\npl.imshow(I2)\npl.title('Image 2')\n\npl.show()\n\n#%% Image conversion and dataset generation\n\ndef im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))\n\ndef mat2im(X,shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\nX1=im2mat(I1)\nX2=im2mat(I2)\n\n# training samples\nnb=1000\nidx1=np.random.randint(X1.shape[0],size=(nb,))\nidx2=np.random.randint(X2.shape[0],size=(nb,))\n\nxs=X1[idx1,:]\nxt=X2[idx2,:]\n\n#%% Plot image distributions\n\n\npl.figure(2,(10,5))\n\npl.subplot(1,2,1)\npl.scatter(xs[:,0],xs[:,2],c=xs)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1,2,2)\n#pl.imshow(I2)\npl.scatter(xt[:,0],xt[:,2],c=xt)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\n\npl.show()\n\n\n\n#%% domain adaptation between images\n\n# LP problem\nda_emd=ot.da.OTDA() # init class\nda_emd.fit(xs,xt) # fit distributions\n\n\n# sinkhorn regularization\nlambd=1e-1\nda_entrop=ot.da.OTDA_sinkhorn()\nda_entrop.fit(xs,xt,reg=lambd)\n\n\n\n#%% prediction between images (using out of sample prediction as in [6])\n\nX1t=da_emd.predict(X1)\nX2t=da_emd.predict(X2,-1)\n\n\nX1te=da_entrop.predict(X1)\nX2te=da_entrop.predict(X2,-1)\n\n\ndef minmax(I):\n return np.minimum(np.maximum(I,0),1)\n\nI1t=minmax(mat2im(X1t,I1.shape))\nI2t=minmax(mat2im(X2t,I2.shape))\n\nI1te=minmax(mat2im(X1te,I1.shape))\nI2te=minmax(mat2im(X2te,I2.shape))\n\n#%% plot all images\n\npl.figure(2,(10,8))\n\npl.subplot(2,3,1)\n\npl.imshow(I1)\npl.title('Image 1')\n\npl.subplot(2,3,2)\npl.imshow(I1t)\npl.title('Image 1 Adapt')\n\n\npl.subplot(2,3,3)\npl.imshow(I1te)\npl.title('Image 1 Adapt (reg)')\n\npl.subplot(2,3,4)\n\npl.imshow(I2)\npl.title('Image 2')\n\npl.subplot(2,3,5)\npl.imshow(I2t)\npl.title('Image 2 Adapt')\n\n\npl.subplot(2,3,6)\npl.imshow(I2te)\npl.title('Image 2 Adapt (reg)')\n\npl.show()"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OTDA_color_images.py b/docs/source/auto_examples/plot_OTDA_color_images.py
deleted file mode 100644
index 68eee44..0000000
--- a/docs/source/auto_examples/plot_OTDA_color_images.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-========================================================
-OT for domain adaptation with image color adaptation [6]
-========================================================
-
-[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
-"""
-
-import numpy as np
-import scipy.ndimage as spi
-import matplotlib.pylab as pl
-import ot
-
-
-#%% Loading images
-
-I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
-I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
-
-#%% Plot images
-
-pl.figure(1)
-
-pl.subplot(1,2,1)
-pl.imshow(I1)
-pl.title('Image 1')
-
-pl.subplot(1,2,2)
-pl.imshow(I2)
-pl.title('Image 2')
-
-pl.show()
-
-#%% Image conversion and dataset generation
-
-def im2mat(I):
- """Converts and image to matrix (one pixel per line)"""
- return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
-
-def mat2im(X,shape):
- """Converts back a matrix to an image"""
- return X.reshape(shape)
-
-X1=im2mat(I1)
-X2=im2mat(I2)
-
-# training samples
-nb=1000
-idx1=np.random.randint(X1.shape[0],size=(nb,))
-idx2=np.random.randint(X2.shape[0],size=(nb,))
-
-xs=X1[idx1,:]
-xt=X2[idx2,:]
-
-#%% Plot image distributions
-
-
-pl.figure(2,(10,5))
-
-pl.subplot(1,2,1)
-pl.scatter(xs[:,0],xs[:,2],c=xs)
-pl.axis([0,1,0,1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
-
-pl.subplot(1,2,2)
-#pl.imshow(I2)
-pl.scatter(xt[:,0],xt[:,2],c=xt)
-pl.axis([0,1,0,1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-
-pl.show()
-
-
-
-#%% domain adaptation between images
-
-# LP problem
-da_emd=ot.da.OTDA() # init class
-da_emd.fit(xs,xt) # fit distributions
-
-
-# sinkhorn regularization
-lambd=1e-1
-da_entrop=ot.da.OTDA_sinkhorn()
-da_entrop.fit(xs,xt,reg=lambd)
-
-
-
-#%% prediction between images (using out of sample prediction as in [6])
-
-X1t=da_emd.predict(X1)
-X2t=da_emd.predict(X2,-1)
-
-
-X1te=da_entrop.predict(X1)
-X2te=da_entrop.predict(X2,-1)
-
-
-def minmax(I):
- return np.minimum(np.maximum(I,0),1)
-
-I1t=minmax(mat2im(X1t,I1.shape))
-I2t=minmax(mat2im(X2t,I2.shape))
-
-I1te=minmax(mat2im(X1te,I1.shape))
-I2te=minmax(mat2im(X2te,I2.shape))
-
-#%% plot all images
-
-pl.figure(2,(10,8))
-
-pl.subplot(2,3,1)
-
-pl.imshow(I1)
-pl.title('Image 1')
-
-pl.subplot(2,3,2)
-pl.imshow(I1t)
-pl.title('Image 1 Adapt')
-
-
-pl.subplot(2,3,3)
-pl.imshow(I1te)
-pl.title('Image 1 Adapt (reg)')
-
-pl.subplot(2,3,4)
-
-pl.imshow(I2)
-pl.title('Image 2')
-
-pl.subplot(2,3,5)
-pl.imshow(I2t)
-pl.title('Image 2 Adapt')
-
-
-pl.subplot(2,3,6)
-pl.imshow(I2te)
-pl.title('Image 2 Adapt (reg)')
-
-pl.show()
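Editor's note: the color-adaptation script deleted above round-trips images through pixel matrices with `im2mat`/`mat2im` and clips the adapted result into the valid range with `minmax`. A self-contained restatement of those three helpers (same behavior; `np.clip` replaces the nested `np.minimum`/`np.maximum`) is:

```python
import numpy as np

def im2mat(img):
    """Flatten an (H, W, C) image into an (H*W, C) matrix with one
    pixel per row, as done before sampling training pixels."""
    return img.reshape((-1, img.shape[2]))

def mat2im(X, shape):
    """Inverse of im2mat: reshape the pixel matrix back to an image."""
    return X.reshape(shape)

def minmax(img):
    """Clip adapted pixel values back into the valid [0, 1] range."""
    return np.clip(img, 0.0, 1.0)
```

The flatten/reshape pair is lossless, which is why the example can adapt `X1`/`X2` in pixel space and still display the results as images.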
diff --git a/docs/source/auto_examples/plot_OTDA_color_images.rst b/docs/source/auto_examples/plot_OTDA_color_images.rst
deleted file mode 100644
index a982a90..0000000
--- a/docs/source/auto_examples/plot_OTDA_color_images.rst
+++ /dev/null
@@ -1,191 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OTDA_color_images.py:
-
-
-========================================================
-OT for domain adaptation with image color adaptation [6]
-========================================================
-
-[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
-
-
-
-
-.. rst-class:: sphx-glr-horizontal
-
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_color_images_001.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_color_images_002.png
- :scale: 47
-
-
-
-
-
-.. code-block:: python
-
-
- import numpy as np
- import scipy.ndimage as spi
- import matplotlib.pylab as pl
- import ot
-
-
- #%% Loading images
-
- I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
- I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
-
- #%% Plot images
-
- pl.figure(1)
-
- pl.subplot(1,2,1)
- pl.imshow(I1)
- pl.title('Image 1')
-
- pl.subplot(1,2,2)
- pl.imshow(I2)
- pl.title('Image 2')
-
- pl.show()
-
- #%% Image conversion and dataset generation
-
-    def im2mat(I):
-        """Converts an image to a matrix (one pixel per line)"""
-        return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
-
-    def mat2im(X,shape):
-        """Converts a matrix back to an image"""
-        return X.reshape(shape)
-
- X1=im2mat(I1)
- X2=im2mat(I2)
-
- # training samples
- nb=1000
- idx1=np.random.randint(X1.shape[0],size=(nb,))
- idx2=np.random.randint(X2.shape[0],size=(nb,))
-
- xs=X1[idx1,:]
- xt=X2[idx2,:]
-
- #%% Plot image distributions
-
-
- pl.figure(2,(10,5))
-
- pl.subplot(1,2,1)
- pl.scatter(xs[:,0],xs[:,2],c=xs)
- pl.axis([0,1,0,1])
- pl.xlabel('Red')
- pl.ylabel('Blue')
- pl.title('Image 1')
-
- pl.subplot(1,2,2)
- #pl.imshow(I2)
- pl.scatter(xt[:,0],xt[:,2],c=xt)
- pl.axis([0,1,0,1])
- pl.xlabel('Red')
- pl.ylabel('Blue')
- pl.title('Image 2')
-
- pl.show()
-
-
-
- #%% domain adaptation between images
-
- # LP problem
- da_emd=ot.da.OTDA() # init class
- da_emd.fit(xs,xt) # fit distributions
-
-
- # sinkhorn regularization
- lambd=1e-1
- da_entrop=ot.da.OTDA_sinkhorn()
- da_entrop.fit(xs,xt,reg=lambd)
-
-
-
-    #%% prediction between images (using out-of-sample prediction as in [6])
-
- X1t=da_emd.predict(X1)
- X2t=da_emd.predict(X2,-1)
-
-
- X1te=da_entrop.predict(X1)
- X2te=da_entrop.predict(X2,-1)
-
-
- def minmax(I):
- return np.minimum(np.maximum(I,0),1)
-
- I1t=minmax(mat2im(X1t,I1.shape))
- I2t=minmax(mat2im(X2t,I2.shape))
-
- I1te=minmax(mat2im(X1te,I1.shape))
- I2te=minmax(mat2im(X2te,I2.shape))
-
- #%% plot all images
-
- pl.figure(2,(10,8))
-
- pl.subplot(2,3,1)
-
- pl.imshow(I1)
- pl.title('Image 1')
-
- pl.subplot(2,3,2)
- pl.imshow(I1t)
- pl.title('Image 1 Adapt')
-
-
- pl.subplot(2,3,3)
- pl.imshow(I1te)
- pl.title('Image 1 Adapt (reg)')
-
- pl.subplot(2,3,4)
-
- pl.imshow(I2)
- pl.title('Image 2')
-
- pl.subplot(2,3,5)
- pl.imshow(I2t)
- pl.title('Image 2 Adapt')
-
-
- pl.subplot(2,3,6)
- pl.imshow(I2te)
- pl.title('Image 2 Adapt (reg)')
-
- pl.show()
-
-**Total running time of the script:** ( 0 minutes 24.815 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OTDA_color_images.py <plot_OTDA_color_images.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OTDA_color_images.ipynb <plot_OTDA_color_images.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
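The deleted scripts obtain the entropic plan through `ot.da.OTDA_sinkhorn` / `ot.sinkhorn(a, b, M, reg)`. A minimal NumPy sketch of the underlying Sinkhorn-Knopp iteration (illustrative only, not POT's implementation; the toy histograms and `reg=0.5` are arbitrary choices):

```python
import numpy as np

def sinkhorn(a, b, M, reg, n_iter=1000):
    """Entropic OT via Sinkhorn-Knopp: returns a coupling whose marginals
    approach the histograms a and b as the iterations proceed."""
    K = np.exp(-M / reg)      # Gibbs kernel of the cost matrix
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)     # scale columns toward the target marginal
        u = a / (K @ v)       # scale rows toward the source marginal
    return u[:, None] * K * v[None, :]

x = np.arange(5, dtype=np.float64)
a = np.ones(5) / 5                       # uniform source histogram
b = np.array([0.1, 0.2, 0.4, 0.2, 0.1])  # target histogram (sums to 1)
M = (x[:, None] - x[None, :]) ** 2       # squared-distance cost
M /= M.max()

G = sinkhorn(a, b, M, reg=0.5)           # coupling between the two histograms
```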
diff --git a/docs/source/auto_examples/plot_OTDA_mapping.ipynb b/docs/source/auto_examples/plot_OTDA_mapping.ipynb
deleted file mode 100644
index ec405af..0000000
--- a/docs/source/auto_examples/plot_OTDA_mapping.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n===============================================\nOT mapping estimation for domain adaptation [8]\n===============================================\n\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, \"Mapping estimation for\n discrete optimal transport\", Neural Information Processing Systems (NIPS), 2016.\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% dataset generation\n\nnp.random.seed(0) # makes example reproducible\n\nn=100 # nb samples in source and target datasets\ntheta=2*np.pi/20\nnz=0.1\nxs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)\nxt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)\n\n# one of the target mode changes its variance (no linear mapping)\nxt[yt==2]*=3\nxt=xt+4\n\n\n#%% plot samples\n\npl.figure(1,(8,5))\npl.clf()\n\npl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')\n\npl.legend(loc=0)\npl.title('Source and target distributions')\n\n\n\n#%% OT linear mapping estimation\n\neta=1e-8 # quadratic regularization for regression\nmu=1e0 # weight of the OT linear term\nbias=True # estimate a bias\n\not_mapping=ot.da.OTDA_mapping_linear()\not_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)\n\nxst=ot_mapping.predict(xs) # use the estimated mapping\nxst0=ot_mapping.interp() # use barycentric mapping\n\n\npl.figure(2,(10,7))\npl.clf()\npl.subplot(2,2,1)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')\npl.title(\"barycentric mapping\")\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)\npl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')\npl.title(\"Learned mapping\")\n\n\n\n#%% Kernel mapping estimation\n\neta=1e-5 # quadratic regularization for regression\nmu=1e-1 # weight of the OT linear term\nbias=True # estimate a bias\nsigma=1 # sigma bandwidth fot gaussian kernel\n\n\not_mapping_kernel=ot.da.OTDA_mapping_kernel()\not_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)\n\nxst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping\nxst0_kernel=ot_mapping_kernel.interp() # use 
barycentric mapping\n\n\n#%% Plotting the mapped samples\n\npl.figure(2,(10,7))\npl.clf()\npl.subplot(2,2,1)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')\npl.title(\"Bary. mapping (linear)\")\npl.legend(loc=0)\n\npl.subplot(2,2,2)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')\npl.title(\"Estim. mapping (linear)\")\n\npl.subplot(2,2,3)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')\npl.title(\"Bary. mapping (kernel)\")\n\npl.subplot(2,2,4)\npl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)\npl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')\npl.title(\"Estim. mapping (kernel)\")"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OTDA_mapping.py b/docs/source/auto_examples/plot_OTDA_mapping.py
deleted file mode 100644
index 78b57e7..0000000
--- a/docs/source/auto_examples/plot_OTDA_mapping.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-===============================================
-OT mapping estimation for domain adaptation [8]
-===============================================
-
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
- discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
-"""
-
-import numpy as np
-import matplotlib.pylab as pl
-import ot
-
-
-
-#%% dataset generation
-
-np.random.seed(0) # makes example reproducible
-
-n=100 # nb samples in source and target datasets
-theta=2*np.pi/20
-nz=0.1
-xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)
-xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)
-
-# one of the target modes changes its variance (no linear mapping can match it)
-xt[yt==2]*=3
-xt=xt+4
-
-
-#%% plot samples
-
-pl.figure(1,(8,5))
-pl.clf()
-
-pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
-pl.legend(loc=0)
-pl.title('Source and target distributions')
-
-
-
-#%% OT linear mapping estimation
-
-eta=1e-8 # quadratic regularization for regression
-mu=1e0 # weight of the OT linear term
-bias=True # estimate a bias
-
-ot_mapping=ot.da.OTDA_mapping_linear()
-ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
-
-xst=ot_mapping.predict(xs) # use the estimated mapping
-xst0=ot_mapping.interp() # use barycentric mapping
-
-
-pl.figure(2,(10,7))
-pl.clf()
-pl.subplot(2,2,1)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
-pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')
-pl.title("barycentric mapping")
-
-pl.subplot(2,2,2)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
-pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
-pl.title("Learned mapping")
-
-
-
-#%% Kernel mapping estimation
-
-eta=1e-5 # quadratic regularization for regression
-mu=1e-1 # weight of the OT linear term
-bias=True # estimate a bias
-sigma=1   # bandwidth of the Gaussian kernel
-
-
-ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
-ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
-
-xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping
-xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping
-
-
-#%% Plotting the mapped samples
-
-pl.figure(2,(10,7))
-pl.clf()
-pl.subplot(2,2,1)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
-pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')
-pl.title("Bary. mapping (linear)")
-pl.legend(loc=0)
-
-pl.subplot(2,2,2)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
-pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
-pl.title("Estim. mapping (linear)")
-
-pl.subplot(2,2,3)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
-pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')
-pl.title("Bary. mapping (kernel)")
-
-pl.subplot(2,2,4)
-pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
-pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')
-pl.title("Estim. mapping (kernel)")
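The mapping example above contrasts `interp()` (barycentric mapping, defined only on the training samples) with a fitted map that extends out of sample. A toy NumPy sketch of that distinction (toy coupling and points, not the algorithm of [8]): project sources through the plan, then fit an affine map by least squares so unseen points can be transported too.

```python
import numpy as np

# toy coupling between 3 source and 3 target points: each row sums to the
# source point's mass (1/3), total mass 1, as an OT plan would
G = np.array([[2.0, 1.0, 0.0],
              [0.0, 2.0, 1.0],
              [1.0, 0.0, 2.0]]) / 9.0
xs = np.array([[0.1, 0.1], [0.5, 0.2], [0.3, 0.8]])
xt = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

# barycentric mapping: each source point moves to the G-weighted average
# of the targets (row-normalized so the weights sum to one)
xst = G @ xt / G.sum(axis=1, keepdims=True)

# affine least-squares fit xs -> xst: unlike the barycentric map, this
# generalizes to points that were not in the training set
Xs1 = np.hstack([xs, np.ones((len(xs), 1))])   # append a bias column
W, *_ = np.linalg.lstsq(Xs1, xst, rcond=None)

x_new = np.array([[0.2, 0.3]])                 # an unseen source point
x_new_mapped = np.hstack([x_new, [[1.0]]]) @ W
```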
diff --git a/docs/source/auto_examples/plot_OTDA_mapping.rst b/docs/source/auto_examples/plot_OTDA_mapping.rst
deleted file mode 100644
index 18da90d..0000000
--- a/docs/source/auto_examples/plot_OTDA_mapping.rst
+++ /dev/null
@@ -1,186 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OTDA_mapping.py:
-
-
-===============================================
-OT mapping estimation for domain adaptation [8]
-===============================================
-
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
- discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
-
-
-
-
-.. rst-class:: sphx-glr-horizontal
-
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_001.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_002.png
- :scale: 47
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out::
-
- It. |Loss |Delta loss
- --------------------------------
- 0|4.009366e+03|0.000000e+00
- 1|3.999933e+03|-2.352753e-03
- 2|3.999520e+03|-1.031984e-04
- 3|3.999362e+03|-3.936391e-05
- 4|3.999281e+03|-2.032868e-05
- 5|3.999238e+03|-1.083083e-05
- 6|3.999229e+03|-2.125291e-06
- It. |Loss |Delta loss
- --------------------------------
- 0|4.026841e+02|0.000000e+00
- 1|3.990791e+02|-8.952439e-03
- 2|3.987954e+02|-7.107124e-04
- 3|3.986554e+02|-3.512453e-04
- 4|3.985721e+02|-2.087997e-04
- 5|3.985141e+02|-1.456184e-04
- 6|3.984729e+02|-1.034624e-04
- 7|3.984435e+02|-7.366943e-05
- 8|3.984199e+02|-5.922497e-05
- 9|3.984016e+02|-4.593063e-05
- 10|3.983867e+02|-3.733061e-05
-
-
-
-
-|
-
-
-.. code-block:: python
-
-
- import numpy as np
- import matplotlib.pylab as pl
- import ot
-
-
-
- #%% dataset generation
-
- np.random.seed(0) # makes example reproducible
-
- n=100 # nb samples in source and target datasets
- theta=2*np.pi/20
- nz=0.1
- xs,ys=ot.datasets.get_data_classif('gaussrot',n,nz=nz)
- xt,yt=ot.datasets.get_data_classif('gaussrot',n,theta=theta,nz=nz)
-
-    # one of the target modes changes its variance (no linear mapping can match it)
- xt[yt==2]*=3
- xt=xt+4
-
-
- #%% plot samples
-
- pl.figure(1,(8,5))
- pl.clf()
-
- pl.scatter(xs[:,0],xs[:,1],c=ys,marker='+',label='Source samples')
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples')
-
- pl.legend(loc=0)
- pl.title('Source and target distributions')
-
-
-
- #%% OT linear mapping estimation
-
- eta=1e-8 # quadratic regularization for regression
- mu=1e0 # weight of the OT linear term
- bias=True # estimate a bias
-
- ot_mapping=ot.da.OTDA_mapping_linear()
- ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
-
- xst=ot_mapping.predict(xs) # use the estimated mapping
- xst0=ot_mapping.interp() # use barycentric mapping
-
-
- pl.figure(2,(10,7))
- pl.clf()
- pl.subplot(2,2,1)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
- pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='barycentric mapping')
- pl.title("barycentric mapping")
-
- pl.subplot(2,2,2)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.3)
- pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
- pl.title("Learned mapping")
-
-
-
- #%% Kernel mapping estimation
-
- eta=1e-5 # quadratic regularization for regression
- mu=1e-1 # weight of the OT linear term
- bias=True # estimate a bias
-    sigma=1   # bandwidth of the Gaussian kernel
-
-
- ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
- ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
-
- xst_kernel=ot_mapping_kernel.predict(xs) # use the estimated mapping
- xst0_kernel=ot_mapping_kernel.interp() # use barycentric mapping
-
-
- #%% Plotting the mapped samples
-
- pl.figure(2,(10,7))
- pl.clf()
- pl.subplot(2,2,1)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
- pl.scatter(xst0[:,0],xst0[:,1],c=ys,marker='+',label='Mapped source samples')
- pl.title("Bary. mapping (linear)")
- pl.legend(loc=0)
-
- pl.subplot(2,2,2)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
- pl.scatter(xst[:,0],xst[:,1],c=ys,marker='+',label='Learned mapping')
- pl.title("Estim. mapping (linear)")
-
- pl.subplot(2,2,3)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
- pl.scatter(xst0_kernel[:,0],xst0_kernel[:,1],c=ys,marker='+',label='barycentric mapping')
- pl.title("Bary. mapping (kernel)")
-
- pl.subplot(2,2,4)
- pl.scatter(xt[:,0],xt[:,1],c=yt,marker='o',label='Target samples',alpha=.2)
- pl.scatter(xst_kernel[:,0],xst_kernel[:,1],c=ys,marker='+',label='Learned mapping')
- pl.title("Estim. mapping (kernel)")
-
-**Total running time of the script:** ( 0 minutes 0.882 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OTDA_mapping.py <plot_OTDA_mapping.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OTDA_mapping.ipynb <plot_OTDA_mapping.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb b/docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb
deleted file mode 100644
index 1136cc3..0000000
--- a/docs/source/auto_examples/plot_OTDA_mapping_color_images.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n====================================================================================\nOT for domain adaptation with image color adaptation [6] with mapping estimation [8]\n====================================================================================\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized\n discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, \"Mapping estimation for\n discrete optimal transport\", Neural Information Processing Systems (NIPS), 2016.\n\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport scipy.ndimage as spi\nimport matplotlib.pylab as pl\nimport ot\n\n\n#%% Loading images\n\nI1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256\nI2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256\n\n#%% Plot images\n\npl.figure(1)\n\npl.subplot(1,2,1)\npl.imshow(I1)\npl.title('Image 1')\n\npl.subplot(1,2,2)\npl.imshow(I2)\npl.title('Image 2')\n\npl.show()\n\n#%% Image conversion and dataset generation\n\ndef im2mat(I):\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))\n\ndef mat2im(X,shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\nX1=im2mat(I1)\nX2=im2mat(I2)\n\n# training samples\nnb=1000\nidx1=np.random.randint(X1.shape[0],size=(nb,))\nidx2=np.random.randint(X2.shape[0],size=(nb,))\n\nxs=X1[idx1,:]\nxt=X2[idx2,:]\n\n#%% Plot image distributions\n\n\npl.figure(2,(10,5))\n\npl.subplot(1,2,1)\npl.scatter(xs[:,0],xs[:,2],c=xs)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1,2,2)\n#pl.imshow(I2)\npl.scatter(xt[:,0],xt[:,2],c=xt)\npl.axis([0,1,0,1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\n\npl.show()\n\n\n\n#%% domain adaptation between images\ndef minmax(I):\n return np.minimum(np.maximum(I,0),1)\n# LP problem\nda_emd=ot.da.OTDA() # init class\nda_emd.fit(xs,xt) # fit distributions\n\nX1t=da_emd.predict(X1) # out of sample\nI1t=minmax(mat2im(X1t,I1.shape))\n\n# sinkhorn regularization\nlambd=1e-1\nda_entrop=ot.da.OTDA_sinkhorn()\nda_entrop.fit(xs,xt,reg=lambd)\n\nX1te=da_entrop.predict(X1)\nI1te=minmax(mat2im(X1te,I1.shape))\n\n# linear mapping estimation\neta=1e-8 # quadratic regularization for regression\nmu=1e0 # weight of the OT linear term\nbias=True # estimate a bias\n\not_mapping=ot.da.OTDA_mapping_linear()\not_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)\n\nX1tl=ot_mapping.predict(X1) # use the estimated 
mapping\nI1tl=minmax(mat2im(X1tl,I1.shape))\n\n# nonlinear mapping estimation\neta=1e-2 # quadratic regularization for regression\nmu=1e0 # weight of the OT linear term\nbias=False # estimate a bias\nsigma=1 # sigma bandwidth fot gaussian kernel\n\n\not_mapping_kernel=ot.da.OTDA_mapping_kernel()\not_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)\n\nX1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping\nI1tn=minmax(mat2im(X1tn,I1.shape))\n#%% plot images\n\n\npl.figure(2,(10,8))\n\npl.subplot(2,3,1)\n\npl.imshow(I1)\npl.title('Im. 1')\n\npl.subplot(2,3,2)\n\npl.imshow(I2)\npl.title('Im. 2')\n\n\npl.subplot(2,3,3)\npl.imshow(I1t)\npl.title('Im. 1 Interp LP')\n\npl.subplot(2,3,4)\npl.imshow(I1te)\npl.title('Im. 1 Interp Entrop')\n\n\npl.subplot(2,3,5)\npl.imshow(I1tl)\npl.title('Im. 1 Linear mapping')\n\npl.subplot(2,3,6)\npl.imshow(I1tn)\npl.title('Im. 1 nonlinear mapping')\n\npl.show()"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OTDA_mapping_color_images.py b/docs/source/auto_examples/plot_OTDA_mapping_color_images.py
deleted file mode 100644
index f07dc6c..0000000
--- a/docs/source/auto_examples/plot_OTDA_mapping_color_images.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-====================================================================================
-OT for domain adaptation with image color adaptation [6] with mapping estimation [8]
-====================================================================================
-
-[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
- discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
- discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
-
-"""
-
-import numpy as np
-import scipy.ndimage as spi
-import matplotlib.pylab as pl
-import ot
-
-
-#%% Loading images
-
-I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
-I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
-
-#%% Plot images
-
-pl.figure(1)
-
-pl.subplot(1,2,1)
-pl.imshow(I1)
-pl.title('Image 1')
-
-pl.subplot(1,2,2)
-pl.imshow(I2)
-pl.title('Image 2')
-
-pl.show()
-
-#%% Image conversion and dataset generation
-
-def im2mat(I):
-    """Converts an image to a matrix (one pixel per line)"""
-    return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
-
-def mat2im(X,shape):
-    """Converts a matrix back to an image"""
-    return X.reshape(shape)
-
-X1=im2mat(I1)
-X2=im2mat(I2)
-
-# training samples
-nb=1000
-idx1=np.random.randint(X1.shape[0],size=(nb,))
-idx2=np.random.randint(X2.shape[0],size=(nb,))
-
-xs=X1[idx1,:]
-xt=X2[idx2,:]
-
-#%% Plot image distributions
-
-
-pl.figure(2,(10,5))
-
-pl.subplot(1,2,1)
-pl.scatter(xs[:,0],xs[:,2],c=xs)
-pl.axis([0,1,0,1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 1')
-
-pl.subplot(1,2,2)
-#pl.imshow(I2)
-pl.scatter(xt[:,0],xt[:,2],c=xt)
-pl.axis([0,1,0,1])
-pl.xlabel('Red')
-pl.ylabel('Blue')
-pl.title('Image 2')
-
-pl.show()
-
-
-
-#%% domain adaptation between images
-def minmax(I):
- return np.minimum(np.maximum(I,0),1)
-# LP problem
-da_emd=ot.da.OTDA() # init class
-da_emd.fit(xs,xt) # fit distributions
-
-X1t=da_emd.predict(X1) # out of sample
-I1t=minmax(mat2im(X1t,I1.shape))
-
-# sinkhorn regularization
-lambd=1e-1
-da_entrop=ot.da.OTDA_sinkhorn()
-da_entrop.fit(xs,xt,reg=lambd)
-
-X1te=da_entrop.predict(X1)
-I1te=minmax(mat2im(X1te,I1.shape))
-
-# linear mapping estimation
-eta=1e-8 # quadratic regularization for regression
-mu=1e0 # weight of the OT linear term
-bias=True # estimate a bias
-
-ot_mapping=ot.da.OTDA_mapping_linear()
-ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
-
-X1tl=ot_mapping.predict(X1) # use the estimated mapping
-I1tl=minmax(mat2im(X1tl,I1.shape))
-
-# nonlinear mapping estimation
-eta=1e-2 # quadratic regularization for regression
-mu=1e0 # weight of the OT linear term
-bias=False # estimate a bias
-sigma=1   # bandwidth of the Gaussian kernel
-
-
-ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
-ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
-
-X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping
-I1tn=minmax(mat2im(X1tn,I1.shape))
-#%% plot images
-
-
-pl.figure(2,(10,8))
-
-pl.subplot(2,3,1)
-
-pl.imshow(I1)
-pl.title('Im. 1')
-
-pl.subplot(2,3,2)
-
-pl.imshow(I2)
-pl.title('Im. 2')
-
-
-pl.subplot(2,3,3)
-pl.imshow(I1t)
-pl.title('Im. 1 Interp LP')
-
-pl.subplot(2,3,4)
-pl.imshow(I1te)
-pl.title('Im. 1 Interp Entrop')
-
-
-pl.subplot(2,3,5)
-pl.imshow(I1tl)
-pl.title('Im. 1 Linear mapping')
-
-pl.subplot(2,3,6)
-pl.imshow(I1tn)
-pl.title('Im. 1 nonlinear mapping')
-
-pl.show()
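The deleted color-mapping script applies the fitted map to every pixel of the full image via `predict`. As a crude, hedged stand-in for such out-of-sample prediction (nearest-training-sample displacement, illustrative only, not the interpolation scheme of [6]; all values are toy data):

```python
import numpy as np

# training pixels and their transported versions (toy values)
xs  = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
xst = np.array([[0.2, 0.1], [0.9, 0.3], [0.1, 0.8]])

def predict(X, xs, xst):
    """Map new points by reusing the displacement of their nearest
    training sample."""
    d = ((X[:, None, :] - xs[None, :, :]) ** 2).sum(axis=-1)  # pairwise sq. dist.
    nn = d.argmin(axis=1)            # index of nearest training pixel
    return X + (xst[nn] - xs[nn])    # apply that pixel's displacement

X_new = np.array([[0.1, 0.05], [0.9, 0.1]])
mapped = predict(X_new, xs, xst)
```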
diff --git a/docs/source/auto_examples/plot_OTDA_mapping_color_images.rst b/docs/source/auto_examples/plot_OTDA_mapping_color_images.rst
deleted file mode 100644
index 60be3a4..0000000
--- a/docs/source/auto_examples/plot_OTDA_mapping_color_images.rst
+++ /dev/null
@@ -1,246 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OTDA_mapping_color_images.py:
-
-
-====================================================================================
-OT for domain adaptation with image color adaptation [6] with mapping estimation [8]
-====================================================================================
-
-[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
- discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
-[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
- discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
-
-
-
-
-
-.. rst-class:: sphx-glr-horizontal
-
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_001.png
- :scale: 47
-
- *
-
- .. image:: /auto_examples/images/sphx_glr_plot_OTDA_mapping_color_images_002.png
- :scale: 47
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out::
-
- It. |Loss |Delta loss
- --------------------------------
- 0|3.624802e+02|0.000000e+00
- 1|3.547180e+02|-2.141395e-02
- 2|3.545494e+02|-4.753955e-04
- 3|3.544646e+02|-2.391784e-04
- 4|3.544126e+02|-1.466280e-04
- 5|3.543775e+02|-9.921805e-05
- 6|3.543518e+02|-7.245828e-05
- 7|3.543323e+02|-5.491924e-05
- 8|3.543170e+02|-4.342401e-05
- 9|3.543046e+02|-3.472174e-05
- 10|3.542945e+02|-2.878681e-05
- 11|3.542859e+02|-2.417065e-05
- 12|3.542786e+02|-2.058131e-05
- 13|3.542723e+02|-1.768262e-05
- 14|3.542668e+02|-1.551616e-05
- 15|3.542620e+02|-1.371909e-05
- 16|3.542577e+02|-1.213326e-05
- 17|3.542538e+02|-1.085481e-05
- 18|3.542531e+02|-1.996006e-06
- It. |Loss |Delta loss
- --------------------------------
- 0|3.555768e+02|0.000000e+00
- 1|3.510071e+02|-1.285164e-02
- 2|3.509110e+02|-2.736701e-04
- 3|3.508748e+02|-1.031476e-04
- 4|3.508506e+02|-6.910585e-05
- 5|3.508330e+02|-5.014608e-05
- 6|3.508195e+02|-3.839166e-05
- 7|3.508090e+02|-3.004218e-05
- 8|3.508005e+02|-2.417627e-05
- 9|3.507935e+02|-2.004621e-05
- 10|3.507876e+02|-1.681731e-05
-
-
-
-
-|
-
-
-.. code-block:: python
-
-
- import numpy as np
- import scipy.ndimage as spi
- import matplotlib.pylab as pl
- import ot
-
-
- #%% Loading images
-
- I1=spi.imread('../data/ocean_day.jpg').astype(np.float64)/256
- I2=spi.imread('../data/ocean_sunset.jpg').astype(np.float64)/256
-
- #%% Plot images
-
- pl.figure(1)
-
- pl.subplot(1,2,1)
- pl.imshow(I1)
- pl.title('Image 1')
-
- pl.subplot(1,2,2)
- pl.imshow(I2)
- pl.title('Image 2')
-
- pl.show()
-
- #%% Image conversion and dataset generation
-
-    def im2mat(I):
-        """Converts an image to a matrix (one pixel per line)"""
-        return I.reshape((I.shape[0]*I.shape[1],I.shape[2]))
-
-    def mat2im(X,shape):
-        """Converts a matrix back to an image"""
-        return X.reshape(shape)
-
- X1=im2mat(I1)
- X2=im2mat(I2)
-
- # training samples
- nb=1000
- idx1=np.random.randint(X1.shape[0],size=(nb,))
- idx2=np.random.randint(X2.shape[0],size=(nb,))
-
- xs=X1[idx1,:]
- xt=X2[idx2,:]
-
- #%% Plot image distributions
-
-
- pl.figure(2,(10,5))
-
- pl.subplot(1,2,1)
- pl.scatter(xs[:,0],xs[:,2],c=xs)
- pl.axis([0,1,0,1])
- pl.xlabel('Red')
- pl.ylabel('Blue')
- pl.title('Image 1')
-
- pl.subplot(1,2,2)
- #pl.imshow(I2)
- pl.scatter(xt[:,0],xt[:,2],c=xt)
- pl.axis([0,1,0,1])
- pl.xlabel('Red')
- pl.ylabel('Blue')
- pl.title('Image 2')
-
- pl.show()
-
-
-
- #%% domain adaptation between images
- def minmax(I):
- return np.minimum(np.maximum(I,0),1)
- # LP problem
- da_emd=ot.da.OTDA() # init class
- da_emd.fit(xs,xt) # fit distributions
-
- X1t=da_emd.predict(X1) # out of sample
- I1t=minmax(mat2im(X1t,I1.shape))
-
- # sinkhorn regularization
- lambd=1e-1
- da_entrop=ot.da.OTDA_sinkhorn()
- da_entrop.fit(xs,xt,reg=lambd)
-
- X1te=da_entrop.predict(X1)
- I1te=minmax(mat2im(X1te,I1.shape))
-
- # linear mapping estimation
- eta=1e-8 # quadratic regularization for regression
- mu=1e0 # weight of the OT linear term
- bias=True # estimate a bias
-
- ot_mapping=ot.da.OTDA_mapping_linear()
- ot_mapping.fit(xs,xt,mu=mu,eta=eta,bias=bias,numItermax = 20,verbose=True)
-
- X1tl=ot_mapping.predict(X1) # use the estimated mapping
- I1tl=minmax(mat2im(X1tl,I1.shape))
-
- # nonlinear mapping estimation
- eta=1e-2 # quadratic regularization for regression
- mu=1e0 # weight of the OT linear term
- bias=False # estimate a bias
-    sigma=1   # bandwidth of the Gaussian kernel
-
-
- ot_mapping_kernel=ot.da.OTDA_mapping_kernel()
- ot_mapping_kernel.fit(xs,xt,mu=mu,eta=eta,sigma=sigma,bias=bias,numItermax = 10,verbose=True)
-
- X1tn=ot_mapping_kernel.predict(X1) # use the estimated mapping
- I1tn=minmax(mat2im(X1tn,I1.shape))
- #%% plot images
-
-
- pl.figure(2,(10,8))
-
- pl.subplot(2,3,1)
-
- pl.imshow(I1)
- pl.title('Im. 1')
-
- pl.subplot(2,3,2)
-
- pl.imshow(I2)
- pl.title('Im. 2')
-
-
- pl.subplot(2,3,3)
- pl.imshow(I1t)
- pl.title('Im. 1 Interp LP')
-
- pl.subplot(2,3,4)
- pl.imshow(I1te)
- pl.title('Im. 1 Interp Entrop')
-
-
- pl.subplot(2,3,5)
- pl.imshow(I1tl)
- pl.title('Im. 1 Linear mapping')
-
- pl.subplot(2,3,6)
- pl.imshow(I1tn)
- pl.title('Im. 1 nonlinear mapping')
-
- pl.show()
-
-**Total running time of the script:** ( 1 minutes 59.537 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OTDA_mapping_color_images.py <plot_OTDA_mapping_color_images.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OTDA_mapping_color_images.ipynb <plot_OTDA_mapping_color_images.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_OT_1D.ipynb b/docs/source/auto_examples/plot_OT_1D.ipynb
index 8715b97..26748c2 100644
--- a/docs/source/auto_examples/plot_OT_1D.ipynb
+++ b/docs/source/auto_examples/plot_OT_1D.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# 1D optimal transport\n\n\n@author: rflamary\n\n"
+ "\n# 1D optimal transport\n\n\nThis example illustrates the computation of EMD and Sinkhorn transport plans\nand their visualization.\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,79 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\nb=gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\npl.plot(x,a,'b',label='Source distribution')\npl.plot(x,b,'r',label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2)\not.plot.plot1D_mat(a,b,M,'Cost matrix M')\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Sinkhorn\n\nlambd=1e-3\nGs=ot.sinkhorn(a,b,M,lambd,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot distributions and loss matrix\n----------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Solve EMD\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Solve Sinkhorn\n--------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Sinkhorn\n\nlambd = 1e-3\nGs = ot.sinkhorn(a, b, M, lambd, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')\n\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_OT_1D.py b/docs/source/auto_examples/plot_OT_1D.py
index 6661aa3..719058f 100644
--- a/docs/source/auto_examples/plot_OT_1D.py
+++ b/docs/source/auto_examples/plot_OT_1D.py
@@ -4,53 +4,80 @@
1D optimal transport
====================
-@author: rflamary
+This example illustrates the computation of EMD and Sinkhorn transport plans
+and their visualization.
+
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
from ot.datasets import get_1D_gauss as gauss
+##############################################################################
+# Generate data
+# -------------
+
#%% parameters
-n=100 # nb bins
+n = 100 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
-b=gauss(n,m=60,s=10)
+a = gauss(n, m=20, s=5) # m= mean, s= std
+b = gauss(n, m=60, s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-M/=M.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+
+##############################################################################
+# Plot distributions and loss matrix
+# ----------------------------------
#%% plot the distributions
-pl.figure(1)
-pl.plot(x,a,'b',label='Source distribution')
-pl.plot(x,b,'r',label='Target distribution')
+pl.figure(1, figsize=(6.4, 3))
+pl.plot(x, a, 'b', label='Source distribution')
+pl.plot(x, b, 'r', label='Target distribution')
pl.legend()
#%% plot distributions and loss matrix
-pl.figure(2)
-ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
+pl.figure(2, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
+
+##############################################################################
+# Solve EMD
+# ---------
+
#%% EMD
-G0=ot.emd(a,b,M)
+G0 = ot.emd(a, b, M)
+
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+##############################################################################
+# Solve Sinkhorn
+# --------------
-pl.figure(3)
-ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
#%% Sinkhorn
-lambd=1e-3
-Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
+lambd = 1e-3
+Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
+
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
-pl.figure(4)
-ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
+pl.show()
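The `ot.sinkhorn(a, b, M, lambd)` call added in this script solves entropic-regularized OT. The iteration behind it can be sketched in plain NumPy — a minimal sketch only, not POT's actual implementation, which adds log-domain stabilization, convergence checks, and the verbose output seen in the rendered page:

```python
import numpy as np

def sinkhorn_sketch(a, b, M, reg, n_iter=1000):
    """Entropic OT via Sinkhorn-Knopp scaling iterations (illustrative)."""
    K = np.exp(-M / reg)                 # Gibbs kernel of the cost matrix
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)                # scale columns to match marginal b
        u = a / (K @ v)                  # scale rows to match marginal a
    return u[:, None] * K * v[None, :]   # plan G = diag(u) K diag(v)

# Toy 1D problem mirroring the example: cost on a grid, uniform marginals
n = 10
x = np.arange(n, dtype=np.float64)
M = (x[:, None] - x[None, :]) ** 2
M /= M.max()
a = b = np.ones(n) / n
G = sinkhorn_sketch(a, b, M, reg=1e-1)
```

The marginals of the returned plan approximate `a` and `b`, and a smaller `reg` makes the plan closer to the sharp EMD solution at the cost of slower convergence.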
diff --git a/docs/source/auto_examples/plot_OT_1D.rst b/docs/source/auto_examples/plot_OT_1D.rst
index 44b715b..975a923 100644
--- a/docs/source/auto_examples/plot_OT_1D.rst
+++ b/docs/source/auto_examples/plot_OT_1D.rst
@@ -7,7 +7,80 @@
1D optimal transport
====================
-@author: rflamary
+This example illustrates the computation of EMD and Sinkhorn transport plans
+and their visualization.
+
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ from ot.datasets import get_1D_gauss as gauss
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = gauss(n, m=20, s=5) # m= mean, s= std
+ b = gauss(n, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+
+Plot distributions and loss matrix
+----------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% plot the distributions
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.plot(x, a, 'b', label='Source distribution')
+ pl.plot(x, b, 'r', label='Target distribution')
+ pl.legend()
+
+ #%% plot distributions and loss matrix
+
+ pl.figure(2, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
@@ -25,94 +98,80 @@
.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_002.png
:scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_003.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_1D_004.png
- :scale: 47
+Solve EMD
+---------
-.. rst-class:: sphx-glr-script-out
- Out::
+.. code-block:: python
- It. |Err
- -------------------
- 0|8.187970e-02|
- 10|3.460174e-02|
- 20|6.633335e-03|
- 30|9.797798e-04|
- 40|1.389606e-04|
- 50|1.959016e-05|
- 60|2.759079e-06|
- 70|3.885166e-07|
- 80|5.470605e-08|
- 90|7.702918e-09|
- 100|1.084609e-09|
- 110|1.527180e-10|
+ #%% EMD
+ G0 = ot.emd(a, b, M)
-|
+ pl.figure(3, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
-.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from ot.datasets import get_1D_gauss as gauss
+.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_005.png
+ :align: center
- #%% parameters
- n=100 # nb bins
- # bin positions
- x=np.arange(n,dtype=np.float64)
+Solve Sinkhorn
+--------------
- # Gaussian distributions
- a=gauss(n,m=20,s=5) # m= mean, s= std
- b=gauss(n,m=60,s=10)
- # loss matrix
- M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
- M/=M.max()
- #%% plot the distributions
+.. code-block:: python
- pl.figure(1)
- pl.plot(x,a,'b',label='Source distribution')
- pl.plot(x,b,'r',label='Target distribution')
- pl.legend()
- #%% plot distributions and loss matrix
- pl.figure(2)
- ot.plot.plot1D_mat(a,b,M,'Cost matrix M')
+ #%% Sinkhorn
- #%% EMD
+ lambd = 1e-3
+ Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)
- G0=ot.emd(a,b,M)
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')
- pl.figure(3)
- ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
+ pl.show()
- #%% Sinkhorn
- lambd=1e-3
- Gs=ot.sinkhorn(a,b,M,lambd,verbose=True)
- pl.figure(4)
- ot.plot.plot1D_mat(a,b,Gs,'OT matrix Sinkhorn')
+.. image:: /auto_examples/images/sphx_glr_plot_OT_1D_007.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Err
+ -------------------
+ 0|8.187970e-02|
+ 10|3.460174e-02|
+ 20|6.633335e-03|
+ 30|9.797798e-04|
+ 40|1.389606e-04|
+ 50|1.959016e-05|
+ 60|2.759079e-06|
+ 70|3.885166e-07|
+ 80|5.470605e-08|
+ 90|7.702918e-09|
+ 100|1.084609e-09|
+ 110|1.527180e-10|
+
-**Total running time of the script:** ( 0 minutes 0.674 seconds)
+**Total running time of the script:** ( 0 minutes 1.198 seconds)
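For the sorted 1D histograms used in this example, `ot.emd` admits a simple special case: when both supports are sorted and the cost is convex in the displacement, the optimal plan is produced by the north-west corner rule on the two mass vectors. A self-contained sketch with a hypothetical helper name (not part of POT's API):

```python
import numpy as np

def emd_1d_sorted(a, b):
    """North-west corner rule: optimal 1D plan for sorted supports."""
    G = np.zeros((len(a), len(b)))
    ra, rb = a.astype(float).copy(), b.astype(float).copy()
    i = j = 0
    while i < len(a) and j < len(b):
        m = min(ra[i], rb[j])   # move as much mass as both bins allow
        G[i, j] = m
        ra[i] -= m
        rb[j] -= m
        if ra[i] <= 1e-12:      # source bin emptied: advance the row
            i += 1
        if rb[j] <= 1e-12:      # target bin filled: advance the column
            j += 1
    return G

# Two small histograms with equal total mass
a = np.array([0.5, 0.5])
b = np.array([0.25, 0.75])
G = emd_1d_sorted(a, b)
```

The plan's row and column sums recover `a` and `b` exactly, which is the feasibility constraint `ot.emd` enforces in general via a network-simplex solver.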
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.ipynb b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
index fad0467..41a37f3 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.ipynb
+++ b/docs/source/auto_examples/plot_OT_2D_samples.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# 2D Optimal transport between empirical distributions\n\n\n@author: rflamary\n\n"
+ "\n# 2D Optimal transport between empirical distributions\n\n\nIllustration of 2D optimal transport between discributions that are weighted\nsum of diracs. The OT matrix is plotted with the samples.\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,79 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nn=50 # nb samples\n\nmu_s=np.array([0,0])\ncov_s=np.array([[1,0],[0,1]])\n\nmu_t=np.array([4,4])\ncov_t=np.array([[1,-.8],[-.8,1]])\n\nxs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)\nxt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)\n\na,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM=ot.dist(xs,xt)\nM/=M.max()\n\n#%% plot samples\n\npl.figure(1)\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('Source and traget distributions')\n\npl.figure(2)\npl.imshow(M,interpolation='nearest')\npl.title('Cost matrix M')\n\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\npl.imshow(G0,interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')\n\n\n#%% sinkhorn\n\n# reg term\nlambd=5e-4\n\nGs=ot.sinkhorn(a,b,M,lambd)\n\npl.figure(5)\npl.imshow(Gs,interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])\npl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\npl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters and data generation\n\nn = 50 # nb samples\n\nmu_s = np.array([0, 0])\ncov_s = np.array([[1, 0], [0, 1]])\n\nmu_t = np.array([4, 4])\ncov_t = np.array([[1, -.8], [-.8, 1]])\n\nxs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)\nxt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t)\n\na, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples\n\n# loss matrix\nM = ot.dist(xs, xt)\nM /= M.max()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot data\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot samples\n\npl.figure(1)\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('Source and target distributions')\n\npl.figure(2)\npl.imshow(M, interpolation='nearest')\npl.title('Cost matrix M')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute EMD\n-----------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3)\npl.imshow(G0, interpolation='nearest')\npl.title('OT matrix G0')\n\npl.figure(4)\not.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix with samples')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Sinkhorn\n----------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% sinkhorn\n\n# reg term\nlambd = 1e-3\n\nGs = ot.sinkhorn(a, b, M, lambd)\n\npl.figure(5)\npl.imshow(Gs, interpolation='nearest')\npl.title('OT matrix sinkhorn')\n\npl.figure(6)\not.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.legend(loc=0)\npl.title('OT matrix Sinkhorn with samples')\n\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.py b/docs/source/auto_examples/plot_OT_2D_samples.py
index edfb781..9818ec5 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.py
+++ b/docs/source/auto_examples/plot_OT_2D_samples.py
@@ -4,75 +4,98 @@
2D Optimal transport between empirical distributions
====================================================
-@author: rflamary
+Illustration of 2D optimal transport between distributions that are weighted
+sums of Diracs. The OT matrix is plotted with the samples.
+
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
+##############################################################################
+# Generate data
+# -------------
+
#%% parameters and data generation
-n=50 # nb samples
+n = 50 # nb samples
-mu_s=np.array([0,0])
-cov_s=np.array([[1,0],[0,1]])
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
-mu_t=np.array([4,4])
-cov_t=np.array([[1,-.8],[-.8,1]])
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
-xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
-xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)
+xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t)
-a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
+a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
# loss matrix
-M=ot.dist(xs,xt)
-M/=M.max()
+M = ot.dist(xs, xt)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
#%% plot samples
pl.figure(1)
-pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
-pl.title('Source and traget distributions')
+pl.title('Source and target distributions')
pl.figure(2)
-pl.imshow(M,interpolation='nearest')
+pl.imshow(M, interpolation='nearest')
pl.title('Cost matrix M')
+##############################################################################
+# Compute EMD
+# -----------
#%% EMD
-G0=ot.emd(a,b,M)
+G0 = ot.emd(a, b, M)
pl.figure(3)
-pl.imshow(G0,interpolation='nearest')
+pl.imshow(G0, interpolation='nearest')
pl.title('OT matrix G0')
pl.figure(4)
-ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
-pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix with samples')
+##############################################################################
+# Compute Sinkhorn
+# ----------------
+
#%% sinkhorn
# reg term
-lambd=5e-4
+lambd = 1e-3
-Gs=ot.sinkhorn(a,b,M,lambd)
+Gs = ot.sinkhorn(a, b, M, lambd)
pl.figure(5)
-pl.imshow(Gs,interpolation='nearest')
+pl.imshow(Gs, interpolation='nearest')
pl.title('OT matrix sinkhorn')
pl.figure(6)
-ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
-pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
-pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
+
+pl.show()
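The loss matrix `M = ot.dist(xs, xt)` built in this script is a pairwise squared-Euclidean cost between the two sample sets. A broadcast-based NumPy equivalent — a sketch of what the call computes, not POT's implementation:

```python
import numpy as np

def pairwise_cost(xs, xt, metric='sqeuclidean'):
    """Cost matrix with M[i, j] between samples xs[i] and xt[j]."""
    diff = xs[:, None, :] - xt[None, :, :]   # shape (n_s, n_t, d)
    d2 = np.sum(diff ** 2, axis=-1)          # squared Euclidean distances
    return np.sqrt(d2) if metric == 'euclidean' else d2

# Two small 2D point clouds, offset like the example's Gaussians
rng = np.random.RandomState(0)
xs = rng.randn(5, 2)
xt = rng.randn(4, 2) + 4
M = pairwise_cost(xs, xt)
M /= M.max()   # same normalization as in the script
```

Normalizing by `M.max()` only rescales the problem; it leaves the EMD plan unchanged and makes the Sinkhorn regularization strength comparable across datasets.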
diff --git a/docs/source/auto_examples/plot_OT_2D_samples.rst b/docs/source/auto_examples/plot_OT_2D_samples.rst
index e05e591..5565c54 100644
--- a/docs/source/auto_examples/plot_OT_2D_samples.rst
+++ b/docs/source/auto_examples/plot_OT_2D_samples.rst
@@ -7,131 +7,191 @@
2D Optimal transport between empirical distributions
====================================================
-@author: rflamary
+Illustration of 2D optimal transport between distributions that are weighted
+sums of Diracs. The OT matrix is plotted with the samples.
-.. rst-class:: sphx-glr-horizontal
+.. code-block:: python
- *
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
- .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
- :scale: 47
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_003.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_004.png
- :scale: 47
- *
+Generate data
+-------------
- .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
- :scale: 47
+.. code-block:: python
-.. rst-class:: sphx-glr-script-out
+ #%% parameters and data generation
- Out::
+ n = 50 # nb samples
- ('Warning: numerical errors at iteration', 0)
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+ mu_t = np.array([4, 4])
+ cov_t = np.array([[1, -.8], [-.8, 1]])
+ xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
+ xt = ot.datasets.get_2D_samples_gauss(n, mu_t, cov_t)
+ a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
-|
+ # loss matrix
+ M = ot.dist(xs, xt)
+ M /= M.max()
-.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- #%% parameters and data generation
- n=50 # nb samples
- mu_s=np.array([0,0])
- cov_s=np.array([[1,0],[0,1]])
+Plot data
+---------
- mu_t=np.array([4,4])
- cov_t=np.array([[1,-.8],[-.8,1]])
- xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
- xt=ot.datasets.get_2D_samples_gauss(n,mu_t,cov_t)
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
+.. code-block:: python
- # loss matrix
- M=ot.dist(xs,xt)
- M/=M.max()
#%% plot samples
pl.figure(1)
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
- pl.title('Source and traget distributions')
+ pl.title('Source and target distributions')
pl.figure(2)
- pl.imshow(M,interpolation='nearest')
+ pl.imshow(M, interpolation='nearest')
pl.title('Cost matrix M')
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_001.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_002.png
+ :scale: 47
+
+
+
+
+Compute EMD
+-----------
+
+
+
+.. code-block:: python
+
+
#%% EMD
- G0=ot.emd(a,b,M)
+ G0 = ot.emd(a, b, M)
pl.figure(3)
- pl.imshow(G0,interpolation='nearest')
+ pl.imshow(G0, interpolation='nearest')
pl.title('OT matrix G0')
pl.figure(4)
- ot.plot.plot2D_samples_mat(xs,xt,G0,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix with samples')
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_006.png
+ :scale: 47
+
+
+
+
+Compute Sinkhorn
+----------------
+
+
+
+.. code-block:: python
+
+
#%% sinkhorn
# reg term
- lambd=5e-4
+ lambd = 1e-3
- Gs=ot.sinkhorn(a,b,M,lambd)
+ Gs = ot.sinkhorn(a, b, M, lambd)
pl.figure(5)
- pl.imshow(Gs,interpolation='nearest')
+ pl.imshow(Gs, interpolation='nearest')
pl.title('OT matrix sinkhorn')
pl.figure(6)
- ot.plot.plot2D_samples_mat(xs,xt,Gs,color=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
+ ot.plot.plot2D_samples_mat(xs, xt, Gs, color=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('OT matrix Sinkhorn with samples')
-**Total running time of the script:** ( 0 minutes 0.623 seconds)
+ pl.show()
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_009.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_2D_samples_010.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 3.380 seconds)
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb b/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
index 46283ac..2b9a364 100644
--- a/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# 2D Optimal transport for different metrics\n\n\nStole the figure idea from Fig. 1 and 2 in \nhttps://arxiv.org/pdf/1706.07650.pdf\n\n\n@author: rflamary\n\n"
+ "\n# 2D Optimal transport for different metrics\n\n\n2D OT on empirical distributio with different gound metric.\n\nStole the figure idea from Fig. 1 and 2 in\nhttps://arxiv.org/pdf/1706.07650.pdf\n\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,79 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n#%% parameters and data generation\n\nfor data in range(2):\n\n if data:\n n=20 # nb samples\n xs=np.zeros((n,2))\n xs[:,0]=np.arange(n)+1\n xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...\n \n xt=np.zeros((n,2))\n xt[:,1]=np.arange(n)+1\n else:\n \n n=50 # nb samples\n xtot=np.zeros((n+1,2))\n xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)\n xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)\n \n xs=xtot[:n,:]\n xt=xtot[1:,:]\n \n \n \n a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples\n \n # loss matrix\n M1=ot.dist(xs,xt,metric='euclidean')\n M1/=M1.max()\n \n # loss matrix\n M2=ot.dist(xs,xt,metric='sqeuclidean')\n M2/=M2.max()\n \n # loss matrix\n Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))\n Mp/=Mp.max()\n \n #%% plot samples\n \n pl.figure(1+3*data)\n pl.clf()\n pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n pl.axis('equal')\n pl.title('Source and traget distributions')\n \n pl.figure(2+3*data,(15,5))\n pl.subplot(1,3,1)\n pl.imshow(M1,interpolation='nearest')\n pl.title('Eucidean cost')\n pl.subplot(1,3,2)\n pl.imshow(M2,interpolation='nearest')\n pl.title('Squared Euclidean cost')\n \n pl.subplot(1,3,3)\n pl.imshow(Mp,interpolation='nearest')\n pl.title('Sqrt Euclidean cost')\n #%% EMD\n \n G1=ot.emd(a,b,M1)\n G2=ot.emd(a,b,M2)\n Gp=ot.emd(a,b,Mp)\n \n pl.figure(3+3*data,(15,5))\n \n pl.subplot(1,3,1)\n ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])\n pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n pl.axis('equal')\n #pl.legend(loc=0)\n pl.title('OT Euclidean')\n \n pl.subplot(1,3,2)\n \n ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])\n pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n pl.axis('equal')\n #pl.legend(loc=0)\n pl.title('OT squared Euclidean')\n 
\n pl.subplot(1,3,3)\n \n ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])\n pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')\n pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')\n pl.axis('equal')\n #pl.legend(loc=0)\n pl.title('OT sqrt Euclidean')"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Dataset 1 : uniform sampling\n----------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "n = 20 # nb samples\nxs = np.zeros((n, 2))\nxs[:, 0] = np.arange(n) + 1\nxs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...\n\nxt = np.zeros((n, 2))\nxt[:, 1] = np.arange(n) + 1\n\na, b = ot.unif(n), ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM1 = ot.dist(xs, xt, metric='euclidean')\nM1 /= M1.max()\n\n# loss matrix\nM2 = ot.dist(xs, xt, metric='sqeuclidean')\nM2 /= M2.max()\n\n# loss matrix\nMp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))\nMp /= Mp.max()\n\n# Data\npl.figure(1, figsize=(7, 3))\npl.clf()\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\npl.title('Source and traget distributions')\n\n\n# Cost matrices\npl.figure(2, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\npl.imshow(M1, interpolation='nearest')\npl.title('Euclidean cost')\n\npl.subplot(1, 3, 2)\npl.imshow(M2, interpolation='nearest')\npl.title('Squared Euclidean cost')\n\npl.subplot(1, 3, 3)\npl.imshow(Mp, interpolation='nearest')\npl.title('Sqrt Euclidean cost')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Dataset 1 : Plot OT Matrices\n----------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% EMD\nG1 = ot.emd(a, b, M1)\nG2 = ot.emd(a, b, M2)\nGp = ot.emd(a, b, Mp)\n\n# OT matrices\npl.figure(3, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\not.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT Euclidean')\n\npl.subplot(1, 3, 2)\not.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT squared Euclidean')\n\npl.subplot(1, 3, 3)\not.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT sqrt Euclidean')\npl.tight_layout()\n\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Dataset 2 : Partial circle\n--------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "n = 50 # nb samples\nxtot = np.zeros((n + 1, 2))\nxtot[:, 0] = np.cos(\n (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)\nxtot[:, 1] = np.sin(\n (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)\n\nxs = xtot[:n, :]\nxt = xtot[1:, :]\n\na, b = ot.unif(n), ot.unif(n) # uniform distribution on samples\n\n# loss matrix\nM1 = ot.dist(xs, xt, metric='euclidean')\nM1 /= M1.max()\n\n# loss matrix\nM2 = ot.dist(xs, xt, metric='sqeuclidean')\nM2 /= M2.max()\n\n# loss matrix\nMp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))\nMp /= Mp.max()\n\n\n# Data\npl.figure(4, figsize=(7, 3))\npl.clf()\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\npl.title('Source and traget distributions')\n\n\n# Cost matrices\npl.figure(5, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\npl.imshow(M1, interpolation='nearest')\npl.title('Euclidean cost')\n\npl.subplot(1, 3, 2)\npl.imshow(M2, interpolation='nearest')\npl.title('Squared Euclidean cost')\n\npl.subplot(1, 3, 3)\npl.imshow(Mp, interpolation='nearest')\npl.title('Sqrt Euclidean cost')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Dataset 2 : Plot OT Matrices\n-----------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% EMD\nG1 = ot.emd(a, b, M1)\nG2 = ot.emd(a, b, M2)\nGp = ot.emd(a, b, Mp)\n\n# OT matrices\npl.figure(6, figsize=(7, 3))\n\npl.subplot(1, 3, 1)\not.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT Euclidean')\n\npl.subplot(1, 3, 2)\not.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT squared Euclidean')\n\npl.subplot(1, 3, 3)\not.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])\npl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\npl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')\npl.axis('equal')\n# pl.legend(loc=0)\npl.title('OT sqrt Euclidean')\npl.tight_layout()\n\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.py b/docs/source/auto_examples/plot_OT_L1_vs_L2.py
index 9bb92fe..090e809 100644
--- a/docs/source/auto_examples/plot_OT_L1_vs_L2.py
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.py
@@ -4,105 +4,204 @@
2D Optimal transport for different metrics
==========================================
-Stole the figure idea from Fig. 1 and 2 in
+2D OT on empirical distributions with different ground metrics.
+
+Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
-@author: rflamary
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
-#%% parameters and data generation
-
-for data in range(2):
-
- if data:
- n=20 # nb samples
- xs=np.zeros((n,2))
- xs[:,0]=np.arange(n)+1
- xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
-
- xt=np.zeros((n,2))
- xt[:,1]=np.arange(n)+1
- else:
-
- n=50 # nb samples
- xtot=np.zeros((n+1,2))
- xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
- xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
-
- xs=xtot[:n,:]
- xt=xtot[1:,:]
-
-
-
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
-
- # loss matrix
- M1=ot.dist(xs,xt,metric='euclidean')
- M1/=M1.max()
-
- # loss matrix
- M2=ot.dist(xs,xt,metric='sqeuclidean')
- M2/=M2.max()
-
- # loss matrix
- Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
- Mp/=Mp.max()
-
- #%% plot samples
-
- pl.figure(1+3*data)
- pl.clf()
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- pl.title('Source and traget distributions')
-
- pl.figure(2+3*data,(15,5))
- pl.subplot(1,3,1)
- pl.imshow(M1,interpolation='nearest')
- pl.title('Eucidean cost')
- pl.subplot(1,3,2)
- pl.imshow(M2,interpolation='nearest')
- pl.title('Squared Euclidean cost')
-
- pl.subplot(1,3,3)
- pl.imshow(Mp,interpolation='nearest')
- pl.title('Sqrt Euclidean cost')
- #%% EMD
-
- G1=ot.emd(a,b,M1)
- G2=ot.emd(a,b,M2)
- Gp=ot.emd(a,b,Mp)
-
- pl.figure(3+3*data,(15,5))
-
- pl.subplot(1,3,1)
- ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT Euclidean')
-
- pl.subplot(1,3,2)
-
- ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT squared Euclidean')
-
- pl.subplot(1,3,3)
-
- ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT sqrt Euclidean')
+##############################################################################
+# Dataset 1 : uniform sampling
+# ----------------------------
+
+n = 20 # nb samples
+xs = np.zeros((n, 2))
+xs[:, 0] = np.arange(n) + 1
+xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+xt = np.zeros((n, 2))
+xt[:, 1] = np.arange(n) + 1
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M1 = ot.dist(xs, xt, metric='euclidean')
+M1 /= M1.max()
+
+# loss matrix
+M2 = ot.dist(xs, xt, metric='sqeuclidean')
+M2 /= M2.max()
+
+# loss matrix
+Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp /= Mp.max()
+
+# Data
+pl.figure(1, figsize=(7, 3))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+pl.title('Source and target distributions')
+
+
+# Cost matrices
+pl.figure(2, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+pl.imshow(M1, interpolation='nearest')
+pl.title('Euclidean cost')
+
+pl.subplot(1, 3, 2)
+pl.imshow(M2, interpolation='nearest')
+pl.title('Squared Euclidean cost')
+
+pl.subplot(1, 3, 3)
+pl.imshow(Mp, interpolation='nearest')
+pl.title('Sqrt Euclidean cost')
+pl.tight_layout()
+
+##############################################################################
+# Dataset 1 : Plot OT Matrices
+# ----------------------------
+
+
+#%% EMD
+G1 = ot.emd(a, b, M1)
+G2 = ot.emd(a, b, M2)
+Gp = ot.emd(a, b, Mp)
+
+# OT matrices
+pl.figure(3, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT Euclidean')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT squared Euclidean')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT sqrt Euclidean')
+pl.tight_layout()
+
+pl.show()
+
+
+##############################################################################
+# Dataset 2 : Partial circle
+# --------------------------
+
+n = 50 # nb samples
+xtot = np.zeros((n + 1, 2))
+xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+xs = xtot[:n, :]
+xt = xtot[1:, :]
+
+a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+# loss matrix
+M1 = ot.dist(xs, xt, metric='euclidean')
+M1 /= M1.max()
+
+# loss matrix
+M2 = ot.dist(xs, xt, metric='sqeuclidean')
+M2 /= M2.max()
+
+# loss matrix
+Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+Mp /= Mp.max()
+
+
+# Data
+pl.figure(4, figsize=(7, 3))
+pl.clf()
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+pl.title('Source and target distributions')
+
+
+# Cost matrices
+pl.figure(5, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+pl.imshow(M1, interpolation='nearest')
+pl.title('Euclidean cost')
+
+pl.subplot(1, 3, 2)
+pl.imshow(M2, interpolation='nearest')
+pl.title('Squared Euclidean cost')
+
+pl.subplot(1, 3, 3)
+pl.imshow(Mp, interpolation='nearest')
+pl.title('Sqrt Euclidean cost')
+pl.tight_layout()
+
+##############################################################################
+# Dataset 2 : Plot OT Matrices
+# -----------------------------
+
+
+#%% EMD
+G1 = ot.emd(a, b, M1)
+G2 = ot.emd(a, b, M2)
+Gp = ot.emd(a, b, Mp)
+
+# OT matrices
+pl.figure(6, figsize=(7, 3))
+
+pl.subplot(1, 3, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT Euclidean')
+
+pl.subplot(1, 3, 2)
+ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT squared Euclidean')
+
+pl.subplot(1, 3, 3)
+ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.axis('equal')
+# pl.legend(loc=0)
+pl.title('OT sqrt Euclidean')
+pl.tight_layout()
+
+pl.show()
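The three loss matrices built above are monotone transforms of one another (M2 is the elementwise square of M1, Mp its elementwise square root), which is why normalizing each by its max still yields genuinely different OT plans. A minimal sketch of that relationship, using `scipy.spatial.distance.cdist` as a stand-in for `ot.dist` (assumption: SciPy is available; variable names mirror the example above):

```python
import numpy as np
from scipy.spatial.distance import cdist

# toy source/target clouds, mirroring the uniform-sampling dataset above
n = 5
xs = np.zeros((n, 2))
xs[:, 0] = np.arange(n) + 1
xt = np.zeros((n, 2))
xt[:, 1] = np.arange(n) + 1

M1 = cdist(xs, xt, metric='euclidean')    # |x - y|
M2 = cdist(xs, xt, metric='sqeuclidean')  # |x - y|^2
Mp = np.sqrt(M1)                          # |x - y|^(1/2)

# the three ground metrics are pointwise monotone transforms of one another
assert np.allclose(M2, M1 ** 2)
assert np.allclose(Mp ** 2, M1)
```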
diff --git a/docs/source/auto_examples/plot_OT_L1_vs_L2.rst b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
index 4e94bef..a569b50 100644
--- a/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
+++ b/docs/source/auto_examples/plot_OT_L1_vs_L2.rst
@@ -7,11 +7,86 @@
2D Optimal transport for different metrics
==========================================
-Stole the figure idea from Fig. 1 and 2 in
+2D OT on empirical distributions with different ground metrics.
+
+Stole the figure idea from Fig. 1 and 2 in
https://arxiv.org/pdf/1706.07650.pdf
-@author: rflamary
+
+
+
+.. code-block:: python
+
+
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+Dataset 1 : uniform sampling
+----------------------------
+
+
+
+.. code-block:: python
+
+
+ n = 20 # nb samples
+ xs = np.zeros((n, 2))
+ xs[:, 0] = np.arange(n) + 1
+ xs[:, 1] = (np.arange(n) + 1) * -0.001 # to make it strictly convex...
+
+ xt = np.zeros((n, 2))
+ xt[:, 1] = np.arange(n) + 1
+
+ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+ # loss matrix
+ M1 = ot.dist(xs, xt, metric='euclidean')
+ M1 /= M1.max()
+
+ # loss matrix
+ M2 = ot.dist(xs, xt, metric='sqeuclidean')
+ M2 /= M2.max()
+
+ # loss matrix
+ Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+ Mp /= Mp.max()
+
+ # Data
+ pl.figure(1, figsize=(7, 3))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+    pl.title('Source and target distributions')
+
+
+ # Cost matrices
+ pl.figure(2, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
+ pl.title('Squared Euclidean cost')
+
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
@@ -29,130 +104,193 @@ https://arxiv.org/pdf/1706.07650.pdf
.. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_002.png
:scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_003.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_004.png
- :scale: 47
+Dataset 1 : Plot OT Matrices
+----------------------------
+
+
+
+.. code-block:: python
+
+
+
+ #%% EMD
+ G1 = ot.emd(a, b, M1)
+ G2 = ot.emd(a, b, M2)
+ Gp = ot.emd(a, b, Mp)
+
+ # OT matrices
+ pl.figure(3, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT Euclidean')
+
+ pl.subplot(1, 3, 2)
+ ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
+
+ pl.subplot(1, 3, 3)
+ ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
+ :align: center
+
+
+
+
+Dataset 2 : Partial circle
+--------------------------
+
+
+
+.. code-block:: python
+
+
+ n = 50 # nb samples
+ xtot = np.zeros((n + 1, 2))
+ xtot[:, 0] = np.cos(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+ xtot[:, 1] = np.sin(
+ (np.arange(n + 1) + 1.0) * 0.9 / (n + 2) * 2 * np.pi)
+
+ xs = xtot[:n, :]
+ xt = xtot[1:, :]
+
+ a, b = ot.unif(n), ot.unif(n) # uniform distribution on samples
+
+ # loss matrix
+ M1 = ot.dist(xs, xt, metric='euclidean')
+ M1 /= M1.max()
+
+ # loss matrix
+ M2 = ot.dist(xs, xt, metric='sqeuclidean')
+ M2 /= M2.max()
+
+ # loss matrix
+ Mp = np.sqrt(ot.dist(xs, xt, metric='euclidean'))
+ Mp /= Mp.max()
+
+
+ # Data
+ pl.figure(4, figsize=(7, 3))
+ pl.clf()
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+    pl.title('Source and target distributions')
+
+
+ # Cost matrices
+ pl.figure(5, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ pl.imshow(M1, interpolation='nearest')
+ pl.title('Euclidean cost')
+
+ pl.subplot(1, 3, 2)
+ pl.imshow(M2, interpolation='nearest')
+ pl.title('Squared Euclidean cost')
+
+ pl.subplot(1, 3, 3)
+ pl.imshow(Mp, interpolation='nearest')
+ pl.title('Sqrt Euclidean cost')
+ pl.tight_layout()
+
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
*
- .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_005.png
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_007.png
:scale: 47
*
- .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_006.png
+ .. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_008.png
:scale: 47
+Dataset 2 : Plot OT Matrices
+-----------------------------
+
+
.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- #%% parameters and data generation
-
- for data in range(2):
-
- if data:
- n=20 # nb samples
- xs=np.zeros((n,2))
- xs[:,0]=np.arange(n)+1
- xs[:,1]=(np.arange(n)+1)*-0.001 # to make it strictly convex...
-
- xt=np.zeros((n,2))
- xt[:,1]=np.arange(n)+1
- else:
-
- n=50 # nb samples
- xtot=np.zeros((n+1,2))
- xtot[:,0]=np.cos((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
- xtot[:,1]=np.sin((np.arange(n+1)+1.0)*0.9/(n+2)*2*np.pi)
-
- xs=xtot[:n,:]
- xt=xtot[1:,:]
-
-
-
- a,b = ot.unif(n),ot.unif(n) # uniform distribution on samples
-
- # loss matrix
- M1=ot.dist(xs,xt,metric='euclidean')
- M1/=M1.max()
-
- # loss matrix
- M2=ot.dist(xs,xt,metric='sqeuclidean')
- M2/=M2.max()
-
- # loss matrix
- Mp=np.sqrt(ot.dist(xs,xt,metric='euclidean'))
- Mp/=Mp.max()
-
- #%% plot samples
-
- pl.figure(1+3*data)
- pl.clf()
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- pl.title('Source and traget distributions')
-
- pl.figure(2+3*data,(15,5))
- pl.subplot(1,3,1)
- pl.imshow(M1,interpolation='nearest')
- pl.title('Eucidean cost')
- pl.subplot(1,3,2)
- pl.imshow(M2,interpolation='nearest')
- pl.title('Squared Euclidean cost')
-
- pl.subplot(1,3,3)
- pl.imshow(Mp,interpolation='nearest')
- pl.title('Sqrt Euclidean cost')
- #%% EMD
-
- G1=ot.emd(a,b,M1)
- G2=ot.emd(a,b,M2)
- Gp=ot.emd(a,b,Mp)
-
- pl.figure(3+3*data,(15,5))
-
- pl.subplot(1,3,1)
- ot.plot.plot2D_samples_mat(xs,xt,G1,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT Euclidean')
-
- pl.subplot(1,3,2)
-
- ot.plot.plot2D_samples_mat(xs,xt,G2,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT squared Euclidean')
-
- pl.subplot(1,3,3)
-
- ot.plot.plot2D_samples_mat(xs,xt,Gp,c=[.5,.5,1])
- pl.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
- pl.plot(xt[:,0],xt[:,1],'xr',label='Target samples')
- pl.axis('equal')
- #pl.legend(loc=0)
- pl.title('OT sqrt Euclidean')
-
-**Total running time of the script:** ( 0 minutes 1.417 seconds)
+ #%% EMD
+ G1 = ot.emd(a, b, M1)
+ G2 = ot.emd(a, b, M2)
+ Gp = ot.emd(a, b, Mp)
+
+ # OT matrices
+ pl.figure(6, figsize=(7, 3))
+
+ pl.subplot(1, 3, 1)
+ ot.plot.plot2D_samples_mat(xs, xt, G1, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT Euclidean')
+
+ pl.subplot(1, 3, 2)
+ ot.plot.plot2D_samples_mat(xs, xt, G2, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT squared Euclidean')
+
+ pl.subplot(1, 3, 3)
+ ot.plot.plot2D_samples_mat(xs, xt, Gp, c=[.5, .5, 1])
+ pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+ pl.axis('equal')
+ # pl.legend(loc=0)
+ pl.title('OT sqrt Euclidean')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_OT_L1_vs_L2_011.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 1.976 seconds)
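With uniform weights `a, b = ot.unif(n), ot.unif(n)` and equal sample counts, the `ot.emd` plan computed in the script reduces to an optimal one-to-one assignment. A hedged sketch of that special case using SciPy's `linear_sum_assignment` as a stand-in (assumption: SciPy is available; this does not call `ot.emd` itself):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

rng = np.random.RandomState(0)
xs = rng.randn(8, 2)       # hypothetical source cloud
xt = rng.randn(8, 2) + 2   # hypothetical target cloud, shifted

for metric in ('euclidean', 'sqeuclidean'):
    M = cdist(xs, xt, metric=metric)
    rows, cols = linear_sum_assignment(M)  # optimal one-to-one plan
    opt = M[rows, cols].sum()
    # the optimal plan never costs more than the identity matching
    assert opt <= M[np.arange(8), np.arange(8)].sum() + 1e-12
```

Because the ground metrics are different monotone transforms of the same distances, the minimizing assignment (and hence the OT matrix plotted above) can differ between them even on the same point clouds.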
diff --git a/docs/source/auto_examples/plot_OT_conv.ipynb b/docs/source/auto_examples/plot_OT_conv.ipynb
deleted file mode 100644
index 7fc4af0..0000000
--- a/docs/source/auto_examples/plot_OT_conv.ipynb
+++ /dev/null
@@ -1,54 +0,0 @@
-{
- "nbformat_minor": 0,
- "nbformat": 4,
- "cells": [
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "%matplotlib inline"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- },
- {
- "source": [
- "\n# 1D Wasserstein barycenter demo\n\n\n\n@author: rflamary\n\n"
- ],
- "cell_type": "markdown",
- "metadata": {}
- },
- {
- "execution_count": null,
- "cell_type": "code",
- "source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used\nimport scipy as sp\nimport scipy.signal as sps\n#%% parameters\n\nn=10 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\nxx,yy=np.meshgrid(x,x)\n\n\nxpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))\n\nM=ot.dist(xpos)\n\n\nI0=((xx-5)**2+(yy-5)**2<3**2)*1.0\nI1=((xx-7)**2+(yy-7)**2<3**2)*1.0\n\nI0/=I0.sum()\nI1/=I1.sum()\n\ni0=I0.ravel()\ni1=I1.ravel()\n\nM=M[i0>0,:][:,i1>0].copy()\ni0=i0[i0>0]\ni1=i1[i1>0]\nItot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)\n\n\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I1)\n\n\n#%% barycenter computation\n\nalpha=0.5 # 0<=alpha<=1\nweights=np.array([1-alpha,alpha])\n\n\ndef conv2(I,k):\n return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)\n\ndef conv2n(I,k):\n res=np.zeros_like(I)\n for i in range(I.shape[2]):\n res[:,:,i]=conv2(I[:,:,i],k)\n return res\n\n\ndef get_1Dkernel(reg,thr=1e-16,wmax=1024):\n w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)\n x=np.arange(w,dtype=np.float64)\n return np.exp(-((x-w/2)**2)/reg)\n \nthr=1e-16\nreg=1e0\n\nk=get_1Dkernel(reg)\npl.figure(2)\npl.plot(k)\n\nI05=conv2(I0,k)\n\npl.figure(1)\npl.subplot(2,2,1)\npl.imshow(I0)\npl.subplot(2,2,2)\npl.imshow(I05)\n\n#%%\n\nG=ot.emd(i0,i1,M)\nr0=np.sum(M*G)\n\nreg=1e-1\nGs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)\nrs=np.sum(M*Gs)\n\n#%%\n\ndef mylog(u):\n tmp=np.log(u)\n tmp[np.isnan(tmp)]=0\n return tmp\n\ndef sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):\n\n\n a=np.asarray(a,dtype=np.float64)\n b=np.asarray(b,dtype=np.float64)\n \n \n if len(b.shape)>2:\n nbb=b.shape[2]\n a=a[:,:,np.newaxis]\n else:\n nbb=0\n \n\n if log:\n log={'err':[]}\n\n # we assume that no distances are null except those of the diagonal of distances\n if nbb:\n u = 
np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))\n v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))\n a0=1.0/(np.prod(b.shape[:2]))\n else:\n u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))\n v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))\n a0=1.0/(np.prod(b.shape[:2]))\n \n \n k=get_1Dkernel(reg)\n \n if nbb:\n K=lambda I: conv2n(I,k)\n else:\n K=lambda I: conv2(I,k)\n\n cpt = 0\n err=1\n while (err>stopThr and cpt<numItermax):\n uprev = u\n vprev = v\n \n v = np.divide(b, K(u))\n u = np.divide(a, K(v))\n\n if (np.any(np.isnan(u)) or np.any(np.isnan(v)) \n or np.any(np.isinf(u)) or np.any(np.isinf(v))):\n # we have reached the machine precision\n # come back to previous solution and quit loop\n print('Warning: numerical errors at iteration', cpt)\n u = uprev\n v = vprev\n break\n if cpt%10==0:\n # we can speed up the process by checking for the error only all the 10th iterations\n\n err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)\n\n if log:\n log['err'].append(err)\n\n if verbose:\n if cpt%200 ==0:\n print('{:5s}|{:12s}'.format('It.','Err')+'\\n'+'-'*19)\n print('{:5d}|{:8e}|'.format(cpt,err))\n cpt = cpt +1\n if log:\n log['u']=u\n log['v']=v\n \n if nbb: #return only loss \n res=np.zeros((nbb))\n for i in range(nbb):\n res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)\n if log:\n return res,log\n else:\n return res \n \n else: # return OT matrix\n res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))\n if log:\n \n return res,log\n else:\n return res\n\nreg=1e0\nr,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)\na=I0\nb=I1\nu=log['u']\nv=log['v']\n#%% barycenter interpolation"
- ],
- "outputs": [],
- "metadata": {
- "collapsed": false
- }
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 2",
- "name": "python2",
- "language": "python"
- },
- "language_info": {
- "mimetype": "text/x-python",
- "nbconvert_exporter": "python",
- "name": "python",
- "file_extension": ".py",
- "version": "2.7.12",
- "pygments_lexer": "ipython2",
- "codemirror_mode": {
- "version": 2,
- "name": "ipython"
- }
- }
- }
-} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_OT_conv.py b/docs/source/auto_examples/plot_OT_conv.py
deleted file mode 100644
index a86e7a2..0000000
--- a/docs/source/auto_examples/plot_OT_conv.py
+++ /dev/null
@@ -1,200 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-==============================
-1D Wasserstein barycenter demo
-==============================
-
-
-@author: rflamary
-"""
-
-import numpy as np
-import matplotlib.pylab as pl
-import ot
-from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
-import scipy as sp
-import scipy.signal as sps
-#%% parameters
-
-n=10 # nb bins
-
-# bin positions
-x=np.arange(n,dtype=np.float64)
-
-xx,yy=np.meshgrid(x,x)
-
-
-xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))
-
-M=ot.dist(xpos)
-
-
-I0=((xx-5)**2+(yy-5)**2<3**2)*1.0
-I1=((xx-7)**2+(yy-7)**2<3**2)*1.0
-
-I0/=I0.sum()
-I1/=I1.sum()
-
-i0=I0.ravel()
-i1=I1.ravel()
-
-M=M[i0>0,:][:,i1>0].copy()
-i0=i0[i0>0]
-i1=i1[i1>0]
-Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)
-
-
-#%% plot the distributions
-
-pl.figure(1)
-pl.subplot(2,2,1)
-pl.imshow(I0)
-pl.subplot(2,2,2)
-pl.imshow(I1)
-
-
-#%% barycenter computation
-
-alpha=0.5 # 0<=alpha<=1
-weights=np.array([1-alpha,alpha])
-
-
-def conv2(I,k):
- return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)
-
-def conv2n(I,k):
- res=np.zeros_like(I)
- for i in range(I.shape[2]):
- res[:,:,i]=conv2(I[:,:,i],k)
- return res
-
-
-def get_1Dkernel(reg,thr=1e-16,wmax=1024):
- w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)
- x=np.arange(w,dtype=np.float64)
- return np.exp(-((x-w/2)**2)/reg)
-
-thr=1e-16
-reg=1e0
-
-k=get_1Dkernel(reg)
-pl.figure(2)
-pl.plot(k)
-
-I05=conv2(I0,k)
-
-pl.figure(1)
-pl.subplot(2,2,1)
-pl.imshow(I0)
-pl.subplot(2,2,2)
-pl.imshow(I05)
-
-#%%
-
-G=ot.emd(i0,i1,M)
-r0=np.sum(M*G)
-
-reg=1e-1
-Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)
-rs=np.sum(M*Gs)
-
-#%%
-
-def mylog(u):
- tmp=np.log(u)
- tmp[np.isnan(tmp)]=0
- return tmp
-
-def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
-
-
- a=np.asarray(a,dtype=np.float64)
- b=np.asarray(b,dtype=np.float64)
-
-
- if len(b.shape)>2:
- nbb=b.shape[2]
- a=a[:,:,np.newaxis]
- else:
- nbb=0
-
-
- if log:
- log={'err':[]}
-
- # we assume that no distances are null except those of the diagonal of distances
- if nbb:
- u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))
- v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))
- a0=1.0/(np.prod(b.shape[:2]))
- else:
- u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))
- v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))
- a0=1.0/(np.prod(b.shape[:2]))
-
-
- k=get_1Dkernel(reg)
-
- if nbb:
- K=lambda I: conv2n(I,k)
- else:
- K=lambda I: conv2(I,k)
-
- cpt = 0
- err=1
- while (err>stopThr and cpt<numItermax):
- uprev = u
- vprev = v
-
- v = np.divide(b, K(u))
- u = np.divide(a, K(v))
-
- if (np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
- # we have reached the machine precision
- # come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
- u = uprev
- v = vprev
- break
- if cpt%10==0:
- # we can speed up the process by checking for the error only all the 10th iterations
-
- err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
-
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt%200 ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
- cpt = cpt +1
- if log:
- log['u']=u
- log['v']=v
-
- if nbb: #return only loss
- res=np.zeros((nbb))
- for i in range(nbb):
- res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
- if log:
- return res,log
- else:
- return res
-
- else: # return OT matrix
- res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))
- if log:
-
- return res,log
- else:
- return res
-
-reg=1e0
-r,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)
-a=I0
-b=I1
-u=log['u']
-v=log['v']
-#%% barycenter interpolation
diff --git a/docs/source/auto_examples/plot_OT_conv.rst b/docs/source/auto_examples/plot_OT_conv.rst
deleted file mode 100644
index 039bbdb..0000000
--- a/docs/source/auto_examples/plot_OT_conv.rst
+++ /dev/null
@@ -1,241 +0,0 @@
-
-
-.. _sphx_glr_auto_examples_plot_OT_conv.py:
-
-
-==============================
-1D Wasserstein barycenter demo
-==============================
-
-
-@author: rflamary
-
-
-
-
-.. code-block:: pytb
-
- Traceback (most recent call last):
- File "/home/rflamary/.local/lib/python2.7/site-packages/sphinx_gallery/gen_rst.py", line 518, in execute_code_block
- exec(code_block, example_globals)
- File "<string>", line 86, in <module>
- TypeError: unsupported operand type(s) for *: 'float' and 'Mock'
-
-
-
-
-
-.. code-block:: python
-
-
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
- import scipy as sp
- import scipy.signal as sps
- #%% parameters
-
- n=10 # nb bins
-
- # bin positions
- x=np.arange(n,dtype=np.float64)
-
- xx,yy=np.meshgrid(x,x)
-
-
- xpos=np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))
-
- M=ot.dist(xpos)
-
-
- I0=((xx-5)**2+(yy-5)**2<3**2)*1.0
- I1=((xx-7)**2+(yy-7)**2<3**2)*1.0
-
- I0/=I0.sum()
- I1/=I1.sum()
-
- i0=I0.ravel()
- i1=I1.ravel()
-
- M=M[i0>0,:][:,i1>0].copy()
- i0=i0[i0>0]
- i1=i1[i1>0]
- Itot=np.concatenate((I0[:,:,np.newaxis],I1[:,:,np.newaxis]),2)
-
-
- #%% plot the distributions
-
- pl.figure(1)
- pl.subplot(2,2,1)
- pl.imshow(I0)
- pl.subplot(2,2,2)
- pl.imshow(I1)
-
-
- #%% barycenter computation
-
- alpha=0.5 # 0<=alpha<=1
- weights=np.array([1-alpha,alpha])
-
-
- def conv2(I,k):
- return sp.ndimage.convolve1d(sp.ndimage.convolve1d(I,k,axis=1),k,axis=0)
-
- def conv2n(I,k):
- res=np.zeros_like(I)
- for i in range(I.shape[2]):
- res[:,:,i]=conv2(I[:,:,i],k)
- return res
-
-
- def get_1Dkernel(reg,thr=1e-16,wmax=1024):
- w=max(min(wmax,2*int((-np.log(thr)*reg)**(.5))),3)
- x=np.arange(w,dtype=np.float64)
- return np.exp(-((x-w/2)**2)/reg)
-
- thr=1e-16
- reg=1e0
-
- k=get_1Dkernel(reg)
- pl.figure(2)
- pl.plot(k)
-
- I05=conv2(I0,k)
-
- pl.figure(1)
- pl.subplot(2,2,1)
- pl.imshow(I0)
- pl.subplot(2,2,2)
- pl.imshow(I05)
-
- #%%
-
- G=ot.emd(i0,i1,M)
- r0=np.sum(M*G)
-
- reg=1e-1
- Gs=ot.bregman.sinkhorn_knopp(i0,i1,M,reg=reg)
- rs=np.sum(M*Gs)
-
- #%%
-
- def mylog(u):
- tmp=np.log(u)
- tmp[np.isnan(tmp)]=0
- return tmp
-
- def sinkhorn_conv(a,b, reg, numItermax = 1000, stopThr=1e-9, verbose=False, log=False,**kwargs):
-
-
- a=np.asarray(a,dtype=np.float64)
- b=np.asarray(b,dtype=np.float64)
-
-
- if len(b.shape)>2:
- nbb=b.shape[2]
- a=a[:,:,np.newaxis]
- else:
- nbb=0
-
-
- if log:
- log={'err':[]}
-
- # we assume that no distances are null except those of the diagonal of distances
- if nbb:
- u = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(a.shape[:2]))
- v = np.ones((a.shape[0],a.shape[1],nbb))/(np.prod(b.shape[:2]))
- a0=1.0/(np.prod(b.shape[:2]))
- else:
- u = np.ones((a.shape[0],a.shape[1]))/(np.prod(a.shape[:2]))
- v = np.ones((a.shape[0],a.shape[1]))/(np.prod(b.shape[:2]))
- a0=1.0/(np.prod(b.shape[:2]))
-
-
- k=get_1Dkernel(reg)
-
- if nbb:
- K=lambda I: conv2n(I,k)
- else:
- K=lambda I: conv2(I,k)
-
- cpt = 0
- err=1
- while (err>stopThr and cpt<numItermax):
- uprev = u
- vprev = v
-
- v = np.divide(b, K(u))
- u = np.divide(a, K(v))
-
- if (np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
- # we have reached the machine precision
- # come back to previous solution and quit loop
- print('Warning: numerical errors at iteration', cpt)
- u = uprev
- v = vprev
- break
- if cpt%10==0:
- # we can speed up the process by checking for the error only all the 10th iterations
-
- err = np.sum((u-uprev)**2)/np.sum((u)**2)+np.sum((v-vprev)**2)/np.sum((v)**2)
-
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt%200 ==0:
- print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
- print('{:5d}|{:8e}|'.format(cpt,err))
- cpt = cpt +1
- if log:
- log['u']=u
- log['v']=v
-
- if nbb: #return only loss
- res=np.zeros((nbb))
- for i in range(nbb):
- res[i]=np.sum(u[:,i].reshape((-1,1))*K*v[:,i].reshape((1,-1))*M)
- if log:
- return res,log
- else:
- return res
-
- else: # return OT matrix
- res=reg*a0*np.sum(a*mylog(u+(u==0))+b*mylog(v+(v==0)))
- if log:
-
- return res,log
- else:
- return res
-
- reg=1e0
- r,log=sinkhorn_conv(I0,I1,reg,verbose=True,log=True)
- a=I0
- b=I1
- u=log['u']
- v=log['v']
- #%% barycenter interpolation
-
-**Total running time of the script:** ( 0 minutes 0.000 seconds)
-
-
-
-.. container:: sphx-glr-footer
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Python source code: plot_OT_conv.py <plot_OT_conv.py>`
-
-
-
- .. container:: sphx-glr-download
-
- :download:`Download Jupyter notebook: plot_OT_conv.ipynb <plot_OT_conv.ipynb>`
-
-.. rst-class:: sphx-glr-signature
-
- `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_WDA.ipynb b/docs/source/auto_examples/plot_WDA.ipynb
index 408a605..1661c53 100644
--- a/docs/source/auto_examples/plot_WDA.ipynb
+++ b/docs/source/auto_examples/plot_WDA.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# Wasserstein Discriminant Analysis\n\n\n@author: rflamary\n\n"
+      "\n# Wasserstein Discriminant Analysis\n\n\nThis example illustrates the use of WDA as proposed in [11].\n\n\n[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).\nWasserstein Discriminant Analysis.\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,97 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\nfrom ot.dr import wda\n\n\n#%% parameters\n\nn=1000 # nb samples in source and target datasets\nnz=0.2\nxs,ys=ot.datasets.get_data_classif('3gauss',n,nz)\nxt,yt=ot.datasets.get_data_classif('3gauss',n,nz)\n\nnbnoise=8\n\nxs=np.hstack((xs,np.random.randn(n,nbnoise)))\nxt=np.hstack((xt,np.random.randn(n,nbnoise)))\n\n#%% plot samples\n\npl.figure(1)\n\n\npl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')\npl.legend(loc=0)\npl.title('Discriminant dimensions')\n\n\n#%% plot distributions and loss matrix\np=2\nreg=1\nk=10\nmaxiter=100\n\nP,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)\n\n#%% plot samples\n\nxsp=proj(xs)\nxtp=proj(xt)\n\npl.figure(1,(10,5))\n\npl.subplot(1,2,1)\npl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples')\n\n\npl.subplot(1,2,2)\npl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples')"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\n\nfrom ot.dr import wda, fda"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters\n\nn = 1000 # nb samples in source and target datasets\nnz = 0.2\n\n# generate circle dataset\nt = np.random.rand(n) * 2 * np.pi\nys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxs = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nt = np.random.rand(n) * 2 * np.pi\nyt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1\nxt = np.concatenate(\n (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)\nxt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)\n\nnbnoise = 8\n\nxs = np.hstack((xs, np.random.randn(n, nbnoise)))\nxt = np.hstack((xt, np.random.randn(n, nbnoise)))"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot data\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot samples\npl.figure(1, figsize=(6.4, 3.5))\n\npl.subplot(1, 2, 1)\npl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Discriminant dimensions')\n\npl.subplot(1, 2, 2)\npl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')\npl.legend(loc=0)\npl.title('Other dimensions')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Fisher Discriminant Analysis\n------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Compute FDA\np = 2\n\nPfda, projfda = fda(xs, ys, p)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Wasserstein Discriminant Analysis\n-----------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Compute WDA\np = 2\nreg = 1e0\nk = 10\nmaxiter = 100\n\nPwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot 2D projections\n-------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot samples\n\nxsp = projfda(xs)\nxtp = projfda(xt)\n\nxspw = projwda(xs)\nxtpw = projwda(xt)\n\npl.figure(2)\n\npl.subplot(2, 2, 1)\npl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples FDA')\n\npl.subplot(2, 2, 2)\npl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples FDA')\n\npl.subplot(2, 2, 3)\npl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected training samples WDA')\n\npl.subplot(2, 2, 4)\npl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')\npl.legend(loc=0)\npl.title('Projected test samples WDA')\npl.tight_layout()\n\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_WDA.py b/docs/source/auto_examples/plot_WDA.py
index bbe3888..93cc237 100644
--- a/docs/source/auto_examples/plot_WDA.py
+++ b/docs/source/auto_examples/plot_WDA.py
@@ -4,60 +4,124 @@
Wasserstein Discriminant Analysis
=================================
-@author: rflamary
+This example illustrates the use of WDA as proposed in [11].
+
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
+
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
-import ot
-from ot.datasets import get_1D_gauss as gauss
-from ot.dr import wda
+from ot.dr import wda, fda
+
+
+##############################################################################
+# Generate data
+# -------------
#%% parameters
-n=1000 # nb samples in source and target datasets
-nz=0.2
-xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
-xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
+n = 1000 # nb samples in source and target datasets
+nz = 0.2
-nbnoise=8
+# generate circle dataset
+t = np.random.rand(n) * 2 * np.pi
+ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
-xs=np.hstack((xs,np.random.randn(n,nbnoise)))
-xt=np.hstack((xt,np.random.randn(n,nbnoise)))
+t = np.random.rand(n) * 2 * np.pi
+yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
-#%% plot samples
+nbnoise = 8
+
+xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+xt = np.hstack((xt, np.random.randn(n, nbnoise)))
-pl.figure(1)
+##############################################################################
+# Plot data
+# ---------
+#%% plot samples
+pl.figure(1, figsize=(6.4, 3.5))
-pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+pl.subplot(1, 2, 1)
+pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')
+pl.subplot(1, 2, 2)
+pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+pl.legend(loc=0)
+pl.title('Other dimensions')
+pl.tight_layout()
+
+##############################################################################
+# Compute Fisher Discriminant Analysis
+# ------------------------------------
-#%% plot distributions and loss matrix
-p=2
-reg=1
-k=10
-maxiter=100
+#%% Compute FDA
+p = 2
-P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+Pfda, projfda = fda(xs, ys, p)
+
+##############################################################################
+# Compute Wasserstein Discriminant Analysis
+# -----------------------------------------
+
+#%% Compute WDA
+p = 2
+reg = 1e0
+k = 10
+maxiter = 100
+
+Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+
+##############################################################################
+# Plot 2D projections
+# -------------------
#%% plot samples
-xsp=proj(xs)
-xtp=proj(xt)
+xsp = projfda(xs)
+xtp = projfda(xt)
-pl.figure(1,(10,5))
+xspw = projwda(xs)
+xtpw = projwda(xt)
-pl.subplot(1,2,1)
-pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+pl.figure(2)
+
+pl.subplot(2, 2, 1)
+pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
-pl.title('Projected training samples')
+pl.title('Projected training samples FDA')
+pl.subplot(2, 2, 2)
+pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples FDA')
-pl.subplot(1,2,2)
-pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+pl.subplot(2, 2, 3)
+pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
-pl.title('Projected test samples')
+pl.title('Projected training samples WDA')
+
+pl.subplot(2, 2, 4)
+pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
+pl.legend(loc=0)
+pl.title('Projected test samples WDA')
+pl.tight_layout()
+
+pl.show()
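The three-class "circle" dataset built at the top of `plot_WDA.py` can be reproduced standalone with NumPy alone (a sketch; the seeded `RandomState` is an addition for reproducibility, the example itself uses the global random state):

```python
import numpy as np

# Sketch of the dataset from plot_WDA.py: points on circles of radius
# 1, 2 and 3 (one radius per class), plus Gaussian noise and 8 pure-noise
# dimensions that WDA should learn to ignore.
n = 1000     # number of samples
nz = 0.2     # noise level
nbnoise = 8  # number of noise dimensions

rng = np.random.RandomState(0)
t = rng.rand(n) * 2 * np.pi
ys = np.floor(np.arange(n) * 1.0 / n * 3) + 1          # class labels 1, 2, 3
xs = np.concatenate((np.cos(t).reshape((-1, 1)),
                     np.sin(t).reshape((-1, 1))), 1)   # points on the unit circle
xs = xs * ys.reshape(-1, 1) + nz * rng.randn(n, 2)     # radius encodes the class
xs = np.hstack((xs, rng.randn(n, nbnoise)))            # append noise dimensions

print(xs.shape)         # (1000, 10): 2 discriminant + 8 noise dimensions
print(sorted(set(ys)))  # [1.0, 2.0, 3.0]
```

Only the first two dimensions are discriminant, which is why the example plots dimensions 0-1 and 2-3 side by side before projecting.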
diff --git a/docs/source/auto_examples/plot_WDA.rst b/docs/source/auto_examples/plot_WDA.rst
index 540555d..2d83123 100644
--- a/docs/source/auto_examples/plot_WDA.rst
+++ b/docs/source/auto_examples/plot_WDA.rst
@@ -7,108 +7,222 @@
Wasserstein Discriminant Analysis
=================================
-@author: rflamary
+This example illustrates the use of WDA as proposed in [11].
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
+Wasserstein Discriminant Analysis.
-.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
- :align: center
-.. rst-class:: sphx-glr-script-out
+.. code-block:: python
- Out::
- Compiling cost function...
- Computing gradient of cost function...
- iter cost val grad. norm
- 1 +5.2427396265941129e-01 8.16627951e-01
- 2 +1.7904850059627236e-01 1.91366819e-01
- 3 +1.6985797253002377e-01 1.70940682e-01
- 4 +1.3903474972292729e-01 1.28606342e-01
- 5 +7.4961734618782416e-02 6.41973980e-02
- 6 +7.1900245222486239e-02 4.25693592e-02
- 7 +7.0472023318269614e-02 2.34599232e-02
- 8 +6.9917568641317152e-02 5.66542766e-03
- 9 +6.9885086242452696e-02 4.05756115e-04
- 10 +6.9884967432653489e-02 2.16836017e-04
- 11 +6.9884923649884148e-02 5.74961622e-05
- 12 +6.9884921818258436e-02 3.83257203e-05
- 13 +6.9884920459612282e-02 9.97486224e-06
- 14 +6.9884920414414409e-02 7.33567875e-06
- 15 +6.9884920388431387e-02 5.23889187e-06
- 16 +6.9884920385183902e-02 4.91959084e-06
- 17 +6.9884920373983223e-02 3.56451669e-06
- 18 +6.9884920369701245e-02 2.88858709e-06
- 19 +6.9884920361621208e-02 1.82294279e-07
- Terminated - min grad norm reached after 19 iterations, 9.65 seconds.
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
+ import numpy as np
+ import matplotlib.pylab as pl
+ from ot.dr import wda, fda
-|
-.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from ot.datasets import get_1D_gauss as gauss
- from ot.dr import wda
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
#%% parameters
- n=1000 # nb samples in source and target datasets
- nz=0.2
- xs,ys=ot.datasets.get_data_classif('3gauss',n,nz)
- xt,yt=ot.datasets.get_data_classif('3gauss',n,nz)
+ n = 1000 # nb samples in source and target datasets
+ nz = 0.2
+
+ # generate circle dataset
+ t = np.random.rand(n) * 2 * np.pi
+ ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ xs = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+ xs = xs * ys.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+ t = np.random.rand(n) * 2 * np.pi
+ yt = np.floor((np.arange(n) * 1.0 / n * 3)) + 1
+ xt = np.concatenate(
+ (np.cos(t).reshape((-1, 1)), np.sin(t).reshape((-1, 1))), 1)
+ xt = xt * yt.reshape(-1, 1) + nz * np.random.randn(n, 2)
+
+ nbnoise = 8
+
+ xs = np.hstack((xs, np.random.randn(n, nbnoise)))
+ xt = np.hstack((xt, np.random.randn(n, nbnoise)))
- nbnoise=8
- xs=np.hstack((xs,np.random.randn(n,nbnoise)))
- xt=np.hstack((xt,np.random.randn(n,nbnoise)))
- #%% plot samples
- pl.figure(1)
- pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% plot samples
+ pl.figure(1, figsize=(6.4, 3.5))
+
+ pl.subplot(1, 2, 1)
+ pl.scatter(xt[:, 0], xt[:, 1], c=ys, marker='+', label='Source samples')
pl.legend(loc=0)
pl.title('Discriminant dimensions')
+ pl.subplot(1, 2, 2)
+ pl.scatter(xt[:, 2], xt[:, 3], c=ys, marker='+', label='Source samples')
+ pl.legend(loc=0)
+ pl.title('Other dimensions')
+ pl.tight_layout()
+
+
- #%% plot distributions and loss matrix
- p=2
- reg=1
- k=10
- maxiter=100
- P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+.. image:: /auto_examples/images/sphx_glr_plot_WDA_001.png
+ :align: center
+
+
+
+
+Compute Fisher Discriminant Analysis
+------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Compute FDA
+ p = 2
+
+ Pfda, projfda = fda(xs, ys, p)
+
+
+
+
+
+
+
+Compute Wasserstein Discriminant Analysis
+-----------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Compute WDA
+ p = 2
+ reg = 1e0
+ k = 10
+ maxiter = 100
+
+ Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ Compiling cost function...
+ Computing gradient of cost function...
+ iter cost val grad. norm
+ 1 +9.0167295050534191e-01 2.28422652e-01
+ 2 +4.8324990550878105e-01 4.89362707e-01
+ 3 +3.4613154515357075e-01 2.84117562e-01
+ 4 +2.5277108387195002e-01 1.24888750e-01
+ 5 +2.4113858393736629e-01 8.07491482e-02
+ 6 +2.3642108593032782e-01 1.67612140e-02
+ 7 +2.3625721372202199e-01 7.68640008e-03
+ 8 +2.3625461994913738e-01 7.42200784e-03
+ 9 +2.3624493441436939e-01 6.43534105e-03
+ 10 +2.3621901383686217e-01 2.17960585e-03
+ 11 +2.3621854258326572e-01 2.03306749e-03
+ 12 +2.3621696458678049e-01 1.37118721e-03
+ 13 +2.3621569489873540e-01 2.76368907e-04
+ 14 +2.3621565599232983e-01 1.41898134e-04
+ 15 +2.3621564465487518e-01 5.96602069e-05
+ 16 +2.3621564232556647e-01 1.08709521e-05
+ 17 +2.3621564230277003e-01 9.17855656e-06
+ 18 +2.3621564224857586e-01 1.73728345e-06
+ 19 +2.3621564224748123e-01 1.17770019e-06
+ 20 +2.3621564224658587e-01 2.16179383e-07
+ Terminated - min grad norm reached after 20 iterations, 9.20 seconds.
+
+
+Plot 2D projections
+-------------------
+
+
+
+.. code-block:: python
+
#%% plot samples
- xsp=proj(xs)
- xtp=proj(xt)
+ xsp = projfda(xs)
+ xtp = projfda(xt)
+
+ xspw = projwda(xs)
+ xtpw = projwda(xt)
- pl.figure(1,(10,5))
+ pl.figure(2)
- pl.subplot(1,2,1)
- pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+ pl.subplot(2, 2, 1)
+ pl.scatter(xsp[:, 0], xsp[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
- pl.title('Projected training samples')
+ pl.title('Projected training samples FDA')
+ pl.subplot(2, 2, 2)
+ pl.scatter(xtp[:, 0], xtp[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected test samples FDA')
+
+ pl.subplot(2, 2, 3)
+ pl.scatter(xspw[:, 0], xspw[:, 1], c=ys, marker='+', label='Projected samples')
+ pl.legend(loc=0)
+ pl.title('Projected training samples WDA')
- pl.subplot(1,2,2)
- pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+ pl.subplot(2, 2, 4)
+ pl.scatter(xtpw[:, 0], xtpw[:, 1], c=ys, marker='+', label='Projected samples')
pl.legend(loc=0)
- pl.title('Projected test samples')
+ pl.title('Projected test samples WDA')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_WDA_003.png
+ :align: center
+
+
+
-**Total running time of the script:** ( 0 minutes 16.902 seconds)
+**Total running time of the script:** ( 0 minutes 16.182 seconds)
diff --git a/docs/source/auto_examples/plot_barycenter_1D.ipynb b/docs/source/auto_examples/plot_barycenter_1D.ipynb
index 36f3975..a19e0fd 100644
--- a/docs/source/auto_examples/plot_barycenter_1D.ipynb
+++ b/docs/source/auto_examples/plot_barycenter_1D.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# 1D Wasserstein barycenter demo\n\n\n\n@author: rflamary\n\n"
+        "\n# 1D Wasserstein barycenter demo\n\n\nThis example illustrates the computation of regularized Wasserstein Barycenter\nas proposed in [3].\n\n\n[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\nIterative Bregman projections for regularized transportation problems\nSIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,79 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used\nfrom matplotlib.collections import PolyCollection\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std\na2=ot.datasets.get_1D_gauss(n,m=60,s=8)\n\n# creating matrix A containing all distributions\nA=np.vstack((a1,a2)).T\nnbd=A.shape[1]\n\n# loss matrix + normalization\nM=ot.utils.dist0(n)\nM/=M.max()\n\n#%% plot the distributions\n\npl.figure(1)\nfor i in range(nbd):\n pl.plot(x,A[:,i])\npl.title('Distributions')\n\n#%% barycenter computation\n\nalpha=0.2 # 0<=alpha<=1\nweights=np.array([1-alpha,alpha])\n\n# l2bary\nbary_l2=A.dot(weights)\n\n# wasserstein\nreg=1e-3\nbary_wass=ot.bregman.barycenter(A,M,reg,weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2,1,1)\nfor i in range(nbd):\n pl.plot(x,A[:,i])\npl.title('Distributions')\n\npl.subplot(2,1,2)\npl.plot(x,bary_l2,'r',label='l2')\npl.plot(x,bary_wass,'g',label='Wasserstein')\npl.legend()\npl.title('Barycenters')\n\n\n#%% barycenter interpolation\n\nnbalpha=11\nalphalist=np.linspace(0,1,nbalpha)\n\n\nB_l2=np.zeros((n,nbalpha))\n\nB_wass=np.copy(B_l2)\n\nfor i in range(0,nbalpha):\n alpha=alphalist[i]\n weights=np.array([1-alpha,alpha])\n B_l2[:,i]=A.dot(weights)\n B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)\n\n#%% plot interpolation\n\npl.figure(3,(10,5))\n\n#pl.subplot(1,2,1)\ncmap=pl.cm.get_cmap('viridis')\nverts = []\nzs = alphalist\nfor i,z in enumerate(zs):\n ys = B_l2[:,i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\n\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0,1)\nax.set_zlabel('')\nax.set_zlim3d(0, 
B_l2.max()*1.01)\npl.title('Barycenter interpolation with l2')\n\npl.show()\n\npl.figure(4,(10,5))\n\n#pl.subplot(1,2,1)\ncmap=pl.cm.get_cmap('viridis')\nverts = []\nzs = alphalist\nfor i,z in enumerate(zs):\n ys = B_wass[:,i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\n\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0,1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max()*1.01)\npl.title('Barycenter interpolation with Wasserstein')\n\npl.show()"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\n# necessary for 3d plot even if not used\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\nfrom matplotlib.collections import PolyCollection"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std\na2 = ot.datasets.get_1D_gauss(n, m=60, s=8)\n\n# creating matrix A containing all distributions\nA = np.vstack((a1, a2)).T\nn_distributions = A.shape[1]\n\n# loss matrix + normalization\nM = ot.utils.dist0(n)\nM /= M.max()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot data\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Barycenter computation\n----------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% barycenter computation\n\nalpha = 0.2 # 0<=alpha<=1\nweights = np.array([1 - alpha, alpha])\n\n# l2bary\nbary_l2 = A.dot(weights)\n\n# wasserstein\nreg = 1e-3\nbary_wass = ot.bregman.barycenter(A, M, reg, weights)\n\npl.figure(2)\npl.clf()\npl.subplot(2, 1, 1)\nfor i in range(n_distributions):\n pl.plot(x, A[:, i])\npl.title('Distributions')\n\npl.subplot(2, 1, 2)\npl.plot(x, bary_l2, 'r', label='l2')\npl.plot(x, bary_wass, 'g', label='Wasserstein')\npl.legend()\npl.title('Barycenters')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Barycentric interpolation\n-------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% barycenter interpolation\n\nn_alpha = 11\nalpha_list = np.linspace(0, 1, n_alpha)\n\n\nB_l2 = np.zeros((n, n_alpha))\n\nB_wass = np.copy(B_l2)\n\nfor i in range(0, n_alpha):\n alpha = alpha_list[i]\n weights = np.array([1 - alpha, alpha])\n B_l2[:, i] = A.dot(weights)\n B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)\n\n#%% plot interpolation\n\npl.figure(3)\n\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = alpha_list\nfor i, z in enumerate(zs):\n ys = B_l2[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with l2')\npl.tight_layout()\n\npl.figure(4)\ncmap = pl.cm.get_cmap('viridis')\nverts = []\nzs = alpha_list\nfor i, z in enumerate(zs):\n ys = B_wass[:, i]\n verts.append(list(zip(x, ys)))\n\nax = pl.gcf().gca(projection='3d')\n\npoly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])\npoly.set_alpha(0.7)\nax.add_collection3d(poly, zs=zs, zdir='y')\nax.set_xlabel('x')\nax.set_xlim3d(0, n)\nax.set_ylabel('$\\\\alpha$')\nax.set_ylim3d(0, 1)\nax.set_zlabel('')\nax.set_zlim3d(0, B_l2.max() * 1.01)\npl.title('Barycenter interpolation with Wasserstein')\npl.tight_layout()\n\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_barycenter_1D.py b/docs/source/auto_examples/plot_barycenter_1D.py
index 30eecbf..620936b 100644
--- a/docs/source/auto_examples/plot_barycenter_1D.py
+++ b/docs/source/auto_examples/plot_barycenter_1D.py
@@ -4,135 +4,157 @@
1D Wasserstein barycenter demo
==============================
+This example illustrates the computation of regularized Wasserstein Barycenter
+as proposed in [3].
+
+
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
-@author: rflamary
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
-from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
+# necessary for 3d plot even if not used
+from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection
+##############################################################################
+# Generate data
+# -------------
#%% parameters
-n=100 # nb bins
+n = 100 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
-a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
+a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
# creating matrix A containing all distributions
-A=np.vstack((a1,a2)).T
-nbd=A.shape[1]
+A = np.vstack((a1, a2)).T
+n_distributions = A.shape[1]
# loss matrix + normalization
-M=ot.utils.dist0(n)
-M/=M.max()
+M = ot.utils.dist0(n)
+M /= M.max()
+
+##############################################################################
+# Plot data
+# ---------
#%% plot the distributions
-pl.figure(1)
-for i in range(nbd):
- pl.plot(x,A[:,i])
+pl.figure(1, figsize=(6.4, 3))
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
+pl.tight_layout()
+
+##############################################################################
+# Barycenter computation
+# ----------------------
#%% barycenter computation
-alpha=0.2 # 0<=alpha<=1
-weights=np.array([1-alpha,alpha])
+alpha = 0.2 # 0<=alpha<=1
+weights = np.array([1 - alpha, alpha])
# l2bary
-bary_l2=A.dot(weights)
+bary_l2 = A.dot(weights)
# wasserstein
-reg=1e-3
-bary_wass=ot.bregman.barycenter(A,M,reg,weights)
+reg = 1e-3
+bary_wass = ot.bregman.barycenter(A, M, reg, weights)
pl.figure(2)
pl.clf()
-pl.subplot(2,1,1)
-for i in range(nbd):
- pl.plot(x,A[:,i])
+pl.subplot(2, 1, 1)
+for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
-pl.subplot(2,1,2)
-pl.plot(x,bary_l2,'r',label='l2')
-pl.plot(x,bary_wass,'g',label='Wasserstein')
+pl.subplot(2, 1, 2)
+pl.plot(x, bary_l2, 'r', label='l2')
+pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
+pl.tight_layout()
+##############################################################################
+# Barycentric interpolation
+# -------------------------
#%% barycenter interpolation
-nbalpha=11
-alphalist=np.linspace(0,1,nbalpha)
+n_alpha = 11
+alpha_list = np.linspace(0, 1, n_alpha)
-B_l2=np.zeros((n,nbalpha))
+B_l2 = np.zeros((n, n_alpha))
-B_wass=np.copy(B_l2)
+B_wass = np.copy(B_l2)
-for i in range(0,nbalpha):
- alpha=alphalist[i]
- weights=np.array([1-alpha,alpha])
- B_l2[:,i]=A.dot(weights)
- B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
+for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
-pl.figure(3,(10,5))
+pl.figure(3)
-#pl.subplot(1,2,1)
-cmap=pl.cm.get_cmap('viridis')
+cmap = pl.cm.get_cmap('viridis')
verts = []
-zs = alphalist
-for i,z in enumerate(zs):
- ys = B_l2[:,i]
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
-poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
-ax.set_ylim3d(0,1)
+ax.set_ylim3d(0, 1)
ax.set_zlabel('')
-ax.set_zlim3d(0, B_l2.max()*1.01)
+ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with l2')
+pl.tight_layout()
-pl.show()
-
-pl.figure(4,(10,5))
-
-#pl.subplot(1,2,1)
-cmap=pl.cm.get_cmap('viridis')
+pl.figure(4)
+cmap = pl.cm.get_cmap('viridis')
verts = []
-zs = alphalist
-for i,z in enumerate(zs):
- ys = B_wass[:,i]
+zs = alpha_list
+for i, z in enumerate(zs):
+ ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
-poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
-ax.set_ylim3d(0,1)
+ax.set_ylim3d(0, 1)
ax.set_zlabel('')
-ax.set_zlim3d(0, B_l2.max()*1.01)
+ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with Wasserstein')
+pl.tight_layout()
-pl.show() \ No newline at end of file
+pl.show()
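The l2 barycenter computed in `plot_barycenter_1D.py` is just a convex combination of the histograms; a minimal standalone sketch using plain NumPy Gaussians in place of `ot.datasets.get_1D_gauss` (the `gauss_hist` helper is hypothetical, not part of POT):

```python
import numpy as np

# Sketch of the l2 (linear) barycenter from plot_barycenter_1D.py.
n = 100  # nb bins


def gauss_hist(n, m, s):
    # hypothetical helper: discretized Gaussian histogram, normalized to sum to 1
    h = np.exp(-(np.arange(n) - m) ** 2 / (2.0 * s ** 2))
    return h / h.sum()


a1 = gauss_hist(n, m=20, s=5)
a2 = gauss_hist(n, m=60, s=8)
A = np.vstack((a1, a2)).T        # one distribution per column, as in the example

alpha = 0.2                      # interpolation weight, 0 <= alpha <= 1
weights = np.array([1 - alpha, alpha])
bary_l2 = A.dot(weights)         # convex combination of the histograms

print(bary_l2.shape)             # (100,)
print(abs(bary_l2.sum() - 1.0) < 1e-12)  # True: still a probability distribution
```

Unlike this bimodal l2 average, the regularized Wasserstein barycenter `ot.bregman.barycenter(A, M, reg, weights)` displaces mass along the support, which is what the interpolation figures visualize.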
diff --git a/docs/source/auto_examples/plot_barycenter_1D.rst b/docs/source/auto_examples/plot_barycenter_1D.rst
index 1b15c77..f17f2c2 100644
--- a/docs/source/auto_examples/plot_barycenter_1D.rst
+++ b/docs/source/auto_examples/plot_barycenter_1D.rst
@@ -7,171 +7,230 @@
1D Wasserstein barycenter demo
==============================
+This example illustrates the computation of regularized Wasserstein Barycenter
+as proposed in [3].
-@author: rflamary
+[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
+Iterative Bregman projections for regularized transportation problems
+SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
-.. rst-class:: sphx-glr-horizontal
+.. code-block:: python
- *
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
- :scale: 47
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
- *
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ # necessary for 3d plot even if not used
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ from matplotlib.collections import PolyCollection
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_002.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_004.png
- :scale: 47
+Generate data
+-------------
.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from mpl_toolkits.mplot3d import Axes3D #necessary for 3d plot even if not used
- from matplotlib.collections import PolyCollection
-
-
#%% parameters
- n=100 # nb bins
+ n = 100 # nb bins
# bin positions
- x=np.arange(n,dtype=np.float64)
+ x = np.arange(n, dtype=np.float64)
# Gaussian distributions
- a1=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
- a2=ot.datasets.get_1D_gauss(n,m=60,s=8)
+ a1 = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)
# creating matrix A containing all distributions
- A=np.vstack((a1,a2)).T
- nbd=A.shape[1]
+ A = np.vstack((a1, a2)).T
+ n_distributions = A.shape[1]
# loss matrix + normalization
- M=ot.utils.dist0(n)
- M/=M.max()
+ M = ot.utils.dist0(n)
+ M /= M.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
#%% plot the distributions
- pl.figure(1)
- for i in range(nbd):
- pl.plot(x,A[:,i])
+ pl.figure(1, figsize=(6.4, 3))
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_001.png
+ :align: center
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
#%% barycenter computation
- alpha=0.2 # 0<=alpha<=1
- weights=np.array([1-alpha,alpha])
+ alpha = 0.2 # 0<=alpha<=1
+ weights = np.array([1 - alpha, alpha])
# l2bary
- bary_l2=A.dot(weights)
+ bary_l2 = A.dot(weights)
# wasserstein
- reg=1e-3
- bary_wass=ot.bregman.barycenter(A,M,reg,weights)
+ reg = 1e-3
+ bary_wass = ot.bregman.barycenter(A, M, reg, weights)
pl.figure(2)
pl.clf()
- pl.subplot(2,1,1)
- for i in range(nbd):
- pl.plot(x,A[:,i])
+ pl.subplot(2, 1, 1)
+ for i in range(n_distributions):
+ pl.plot(x, A[:, i])
pl.title('Distributions')
- pl.subplot(2,1,2)
- pl.plot(x,bary_l2,'r',label='l2')
- pl.plot(x,bary_wass,'g',label='Wasserstein')
+ pl.subplot(2, 1, 2)
+ pl.plot(x, bary_l2, 'r', label='l2')
+ pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
+ pl.tight_layout()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_003.png
+ :align: center
+
+
+
+
+Barycentric interpolation
+-------------------------
+
+
+
+.. code-block:: python
#%% barycenter interpolation
- nbalpha=11
- alphalist=np.linspace(0,1,nbalpha)
+ n_alpha = 11
+ alpha_list = np.linspace(0, 1, n_alpha)
- B_l2=np.zeros((n,nbalpha))
+ B_l2 = np.zeros((n, n_alpha))
- B_wass=np.copy(B_l2)
+ B_wass = np.copy(B_l2)
- for i in range(0,nbalpha):
- alpha=alphalist[i]
- weights=np.array([1-alpha,alpha])
- B_l2[:,i]=A.dot(weights)
- B_wass[:,i]=ot.bregman.barycenter(A,M,reg,weights)
+ for i in range(0, n_alpha):
+ alpha = alpha_list[i]
+ weights = np.array([1 - alpha, alpha])
+ B_l2[:, i] = A.dot(weights)
+ B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)
#%% plot interpolation
- pl.figure(3,(10,5))
+ pl.figure(3)
- #pl.subplot(1,2,1)
- cmap=pl.cm.get_cmap('viridis')
+ cmap = pl.cm.get_cmap('viridis')
verts = []
- zs = alphalist
- for i,z in enumerate(zs):
- ys = B_l2[:,i]
+ zs = alpha_list
+ for i, z in enumerate(zs):
+ ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
- poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
- ax.set_ylim3d(0,1)
+ ax.set_ylim3d(0, 1)
ax.set_zlabel('')
- ax.set_zlim3d(0, B_l2.max()*1.01)
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with l2')
+ pl.tight_layout()
- pl.show()
-
- pl.figure(4,(10,5))
-
- #pl.subplot(1,2,1)
- cmap=pl.cm.get_cmap('viridis')
+ pl.figure(4)
+ cmap = pl.cm.get_cmap('viridis')
verts = []
- zs = alphalist
- for i,z in enumerate(zs):
- ys = B_wass[:,i]
+ zs = alpha_list
+ for i, z in enumerate(zs):
+ ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
ax = pl.gcf().gca(projection='3d')
- poly = PolyCollection(verts,facecolors=[cmap(a) for a in alphalist])
+ poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
-
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
- ax.set_ylim3d(0,1)
+ ax.set_ylim3d(0, 1)
ax.set_zlabel('')
- ax.set_zlim3d(0, B_l2.max()*1.01)
+ ax.set_zlim3d(0, B_l2.max() * 1.01)
pl.title('Barycenter interpolation with Wasserstein')
+ pl.tight_layout()
pl.show()
-**Total running time of the script:** ( 0 minutes 2.274 seconds)
+
+
+
+.. rst-class:: sphx-glr-horizontal
+
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_005.png
+ :scale: 47
+
+ *
+
+ .. image:: /auto_examples/images/sphx_glr_plot_barycenter_1D_006.png
+ :scale: 47
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.814 seconds)
diff --git a/docs/source/auto_examples/plot_compute_emd.ipynb b/docs/source/auto_examples/plot_compute_emd.ipynb
index 4162144..b9b8bc5 100644
--- a/docs/source/auto_examples/plot_compute_emd.ipynb
+++ b/docs/source/auto_examples/plot_compute_emd.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# 1D optimal transport\n\n\n@author: rflamary\n\n"
+        "\n# Plot multiple EMD\n\n\nShows how to compute multiple EMD and Sinkhorn with two different\nground metrics and plot their values for different distributions.\n\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,79 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss\n\n\n#%% parameters\n\nn=100 # nb bins\nn_target=50 # nb target distributions\n\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\nlst_m=np.linspace(20,90,n_target)\n\n# Gaussian distributions\na=gauss(n,m=20,s=5) # m= mean, s= std\n\nB=np.zeros((n,n_target))\n\nfor i,m in enumerate(lst_m):\n B[:,i]=gauss(n,m=m,s=5)\n\n# loss matrix and normalization\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')\nM/=M.max()\nM2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')\nM2/=M2.max()\n#%% plot the distributions\n\npl.figure(1)\npl.subplot(2,1,1)\npl.plot(x,a,'b',label='Source distribution')\npl.title('Source distribution')\npl.subplot(2,1,2)\npl.plot(x,B,label='Target distributions')\npl.title('Target distributions')\n\n#%% Compute and plot distributions and loss matrix\n\nd_emd=ot.emd2(a,B,M) # direct computation of EMD\nd_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3\n\n\npl.figure(2)\npl.plot(d_emd,label='Euclidean EMD')\npl.plot(d_emd2,label='Squared Euclidean EMD')\npl.title('EMD distances')\npl.legend()\n\n#%%\nreg=1e-2\nd_sinkhorn=ot.sinkhorn(a,B,M,reg)\nd_sinkhorn2=ot.sinkhorn(a,B,M2,reg)\n\npl.figure(2)\npl.clf()\npl.plot(d_emd,label='Euclidean EMD')\npl.plot(d_emd2,label='Squared Euclidean EMD')\npl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')\npl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')\npl.title('EMD distances')\npl.legend()"
+ "# Author: Remi Flamary <remi.flamary@unice.fr>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot\nfrom ot.datasets import get_1D_gauss as gauss"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\nn_target = 50 # nb target distributions\n\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\nlst_m = np.linspace(20, 90, n_target)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\n\nB = np.zeros((n, n_target))\n\nfor i, m in enumerate(lst_m):\n B[:, i] = gauss(n, m=m, s=5)\n\n# loss matrix and normalization\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')\nM /= M.max()\nM2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')\nM2 /= M2.max()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot data\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% plot the distributions\n\npl.figure(1)\npl.subplot(2, 1, 1)\npl.plot(x, a, 'b', label='Source distribution')\npl.title('Source distribution')\npl.subplot(2, 1, 2)\npl.plot(x, B, label='Target distributions')\npl.title('Target distributions')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute EMD for the different losses\n------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Compute and plot distributions and loss matrix\n\nd_emd = ot.emd2(a, B, M) # direct computation of EMD\nd_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2\n\n\npl.figure(2)\npl.plot(d_emd, label='Euclidean EMD')\npl.plot(d_emd2, label='Squared Euclidean EMD')\npl.title('EMD distances')\npl.legend()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Sinkhorn for the different losses\n-----------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%%\nreg = 1e-2\nd_sinkhorn = ot.sinkhorn2(a, B, M, reg)\nd_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)\n\npl.figure(2)\npl.clf()\npl.plot(d_emd, label='Euclidean EMD')\npl.plot(d_emd2, label='Squared Euclidean EMD')\npl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')\npl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')\npl.title('EMD distances')\npl.legend()\n\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_compute_emd.py b/docs/source/auto_examples/plot_compute_emd.py
index c7063e8..73b42c3 100644
--- a/docs/source/auto_examples/plot_compute_emd.py
+++ b/docs/source/auto_examples/plot_compute_emd.py
@@ -1,74 +1,102 @@
# -*- coding: utf-8 -*-
"""
-====================
-1D optimal transport
-====================
+=================
+Plot multiple EMD
+=================
+
+Shows how to compute multiple EMD and Sinkhorn with two different
+ground metrics and plot their values for different distributions.
+
-@author: rflamary
"""
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
import numpy as np
import matplotlib.pylab as pl
import ot
from ot.datasets import get_1D_gauss as gauss
+##############################################################################
+# Generate data
+# -------------
+
#%% parameters
-n=100 # nb bins
-n_target=50 # nb target distributions
+n = 100 # nb bins
+n_target = 50 # nb target distributions
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
-lst_m=np.linspace(20,90,n_target)
+lst_m = np.linspace(20, 90, n_target)
# Gaussian distributions
-a=gauss(n,m=20,s=5) # m= mean, s= std
+a = gauss(n, m=20, s=5) # m= mean, s= std
-B=np.zeros((n,n_target))
+B = np.zeros((n, n_target))
-for i,m in enumerate(lst_m):
- B[:,i]=gauss(n,m=m,s=5)
+for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
# loss matrix and normalization
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
-M/=M.max()
-M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
-M2/=M2.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+M /= M.max()
+M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+M2 /= M2.max()
+
+##############################################################################
+# Plot data
+# ---------
+
#%% plot the distributions
pl.figure(1)
-pl.subplot(2,1,1)
-pl.plot(x,a,'b',label='Source distribution')
+pl.subplot(2, 1, 1)
+pl.plot(x, a, 'b', label='Source distribution')
pl.title('Source distribution')
-pl.subplot(2,1,2)
-pl.plot(x,B,label='Target distributions')
+pl.subplot(2, 1, 2)
+pl.plot(x, B, label='Target distributions')
pl.title('Target distributions')
+pl.tight_layout()
+
+
+##############################################################################
+# Compute EMD for the different losses
+# ------------------------------------
#%% Compute and plot distributions and loss matrix
-d_emd=ot.emd2(a,B,M) # direct computation of EMD
-d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3
+d_emd = ot.emd2(a, B, M) # direct computation of EMD
+d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
pl.figure(2)
-pl.plot(d_emd,label='Euclidean EMD')
-pl.plot(d_emd2,label='Squared Euclidean EMD')
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
pl.title('EMD distances')
pl.legend()
+##############################################################################
+# Compute Sinkhorn for the different losses
+# -----------------------------------------
+
#%%
-reg=1e-2
-d_sinkhorn=ot.sinkhorn(a,B,M,reg)
-d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
+reg = 1e-2
+d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
pl.figure(2)
pl.clf()
-pl.plot(d_emd,label='Euclidean EMD')
-pl.plot(d_emd2,label='Squared Euclidean EMD')
-pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
-pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
+pl.plot(d_emd, label='Euclidean EMD')
+pl.plot(d_emd2, label='Squared Euclidean EMD')
+pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
pl.title('EMD distances')
-pl.legend() \ No newline at end of file
+pl.legend()
+
+pl.show()
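For 1D histograms on a common grid, `ot.emd2` with the Euclidean ground metric reduces to a known closed form: the Wasserstein-1 distance is the L1 distance between the two cumulative distributions. A minimal NumPy-only sketch of that identity, assuming unit bin spacing and normalized histograms (an illustration, not POT's solver):

```python
import numpy as np

def emd_1d(a, b):
    # W1 between two 1D histograms on the same unit-spaced grid:
    # the L1 distance between their cumulative distributions.
    return np.abs(np.cumsum(a - b)).sum()

# Two point masses 3 bins apart: the optimal plan moves all mass 3 bins.
a = np.zeros(10)
a[2] = 1.0
b = np.zeros(10)
b[5] = 1.0
print(emd_1d(a, b))  # -> 3.0
```

Note that the example above normalizes `M` by its maximum, so values returned by `ot.emd2` differ from this closed form by that constant scale.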
diff --git a/docs/source/auto_examples/plot_compute_emd.rst b/docs/source/auto_examples/plot_compute_emd.rst
index 4c7445b..cdbc620 100644
--- a/docs/source/auto_examples/plot_compute_emd.rst
+++ b/docs/source/auto_examples/plot_compute_emd.rst
@@ -3,101 +3,166 @@
.. _sphx_glr_auto_examples_plot_compute_emd.py:
-====================
-1D optimal transport
-====================
+=================
+Plot multiple EMD
+=================
-@author: rflamary
+Shows how to compute multiple EMD and Sinkhorn with two different
+ground metrics and plot their values for different distributions.
-.. rst-class:: sphx-glr-horizontal
+.. code-block:: python
- *
- .. image:: /auto_examples/images/sphx_glr_plot_compute_emd_001.png
- :scale: 47
+ # Author: Remi Flamary <remi.flamary@unice.fr>
+ #
+ # License: MIT License
- *
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+ from ot.datasets import get_1D_gauss as gauss
- .. image:: /auto_examples/images/sphx_glr_plot_compute_emd_002.png
- :scale: 47
-.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
- from ot.datasets import get_1D_gauss as gauss
+Generate data
+-------------
+
+
+
+.. code-block:: python
#%% parameters
- n=100 # nb bins
- n_target=50 # nb target distributions
+ n = 100 # nb bins
+ n_target = 50 # nb target distributions
# bin positions
- x=np.arange(n,dtype=np.float64)
+ x = np.arange(n, dtype=np.float64)
- lst_m=np.linspace(20,90,n_target)
+ lst_m = np.linspace(20, 90, n_target)
# Gaussian distributions
- a=gauss(n,m=20,s=5) # m= mean, s= std
+ a = gauss(n, m=20, s=5) # m= mean, s= std
- B=np.zeros((n,n_target))
+ B = np.zeros((n, n_target))
- for i,m in enumerate(lst_m):
- B[:,i]=gauss(n,m=m,s=5)
+ for i, m in enumerate(lst_m):
+ B[:, i] = gauss(n, m=m, s=5)
# loss matrix and normalization
- M=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'euclidean')
- M/=M.max()
- M2=ot.dist(x.reshape((n,1)),x.reshape((n,1)),'sqeuclidean')
- M2/=M2.max()
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'euclidean')
+ M /= M.max()
+ M2 = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)), 'sqeuclidean')
+ M2 /= M2.max()
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
#%% plot the distributions
pl.figure(1)
- pl.subplot(2,1,1)
- pl.plot(x,a,'b',label='Source distribution')
+ pl.subplot(2, 1, 1)
+ pl.plot(x, a, 'b', label='Source distribution')
pl.title('Source distribution')
- pl.subplot(2,1,2)
- pl.plot(x,B,label='Target distributions')
+ pl.subplot(2, 1, 2)
+ pl.plot(x, B, label='Target distributions')
pl.title('Target distributions')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_001.png
+ :align: center
+
+
+
+
+Compute EMD for the different losses
+------------------------------------
+
+
+
+.. code-block:: python
+
#%% Compute and plot distributions and loss matrix
- d_emd=ot.emd2(a,B,M) # direct computation of EMD
- d_emd2=ot.emd2(a,B,M2) # direct computation of EMD with loss M3
+ d_emd = ot.emd2(a, B, M) # direct computation of EMD
+ d_emd2 = ot.emd2(a, B, M2) # direct computation of EMD with loss M2
pl.figure(2)
- pl.plot(d_emd,label='Euclidean EMD')
- pl.plot(d_emd2,label='Squared Euclidean EMD')
+ pl.plot(d_emd, label='Euclidean EMD')
+ pl.plot(d_emd2, label='Squared Euclidean EMD')
pl.title('EMD distances')
pl.legend()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_003.png
+ :align: center
+
+
+
+
+Compute Sinkhorn for the different losses
+-----------------------------------------
+
+
+
+.. code-block:: python
+
+
#%%
- reg=1e-2
- d_sinkhorn=ot.sinkhorn(a,B,M,reg)
- d_sinkhorn2=ot.sinkhorn(a,B,M2,reg)
+ reg = 1e-2
+ d_sinkhorn = ot.sinkhorn2(a, B, M, reg)
+ d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)
pl.figure(2)
pl.clf()
- pl.plot(d_emd,label='Euclidean EMD')
- pl.plot(d_emd2,label='Squared Euclidean EMD')
- pl.plot(d_sinkhorn,'+',label='Euclidean Sinkhorn')
- pl.plot(d_sinkhorn2,'+',label='Squared Euclidean Sinkhorn')
+ pl.plot(d_emd, label='Euclidean EMD')
+ pl.plot(d_emd2, label='Squared Euclidean EMD')
+ pl.plot(d_sinkhorn, '+', label='Euclidean Sinkhorn')
+ pl.plot(d_sinkhorn2, '+', label='Squared Euclidean Sinkhorn')
pl.title('EMD distances')
pl.legend()
-**Total running time of the script:** ( 0 minutes 0.521 seconds)
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_compute_emd_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.697 seconds)
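Behind `ot.sinkhorn2` sits the entropic-regularized problem, solved by Sinkhorn–Knopp matrix scaling. The core loop can be sketched in a few lines of NumPy; this is a simplified illustration under a fixed iteration count, not POT's stabilized implementation:

```python
import numpy as np

def sinkhorn_plan(a, b, M, reg, n_iter=1000):
    # Sinkhorn-Knopp: alternately rescale the Gibbs kernel K = exp(-M/reg)
    # until the coupling's row/column marginals match a and b.
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]  # regularized transport plan

n = 50
x = np.arange(n, dtype=np.float64)
M = np.abs(x[:, None] - x[None, :])  # Euclidean ground metric on the grid
M /= M.max()
a = np.ones(n) / n                   # uniform source histogram
b = np.ones(n) / n                   # uniform target histogram
T = sinkhorn_plan(a, b, M, reg=1e-2)
cost = np.sum(T * M)                 # entropic transport cost <T, M>
```

For identical marginals the cost is small but nonzero, since the entropic term smooths the plan away from the exact diagonal coupling.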
diff --git a/docs/source/auto_examples/plot_gromov.ipynb b/docs/source/auto_examples/plot_gromov.ipynb
new file mode 100644
index 0000000..865848e
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov.ipynb
@@ -0,0 +1,126 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "\n# Gromov-Wasserstein example\n\n\nThis example is designed to show how to use the Gromov-Wasserstein distance\ncomputation in POT.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Author: Erwan Vautier <erwan.vautier@gmail.com>\r\n# Nicolas Courty <ncourty@irisa.fr>\r\n#\r\n# License: MIT License\r\n\r\nimport scipy as sp\r\nimport numpy as np\r\nimport matplotlib.pylab as pl\r\nfrom mpl_toolkits.mplot3d import Axes3D # noqa\r\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "Sample two Gaussian distributions (2D and 3D)\r\n---------------------------------------------\r\n\r\nThe Gromov-Wasserstein distance makes it possible to compute distances between\r\nsamples that do not live in the same metric space. For demonstration purposes,\r\nwe sample two Gaussian distributions in 2- and 3-dimensional spaces.\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "n_samples = 30 # nb samples\r\n\r\nmu_s = np.array([0, 0])\r\ncov_s = np.array([[1, 0], [0, 1]])\r\n\r\nmu_t = np.array([4, 4, 4])\r\ncov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\r\n\r\n\r\nxs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)\r\nP = sp.linalg.sqrtm(cov_t)\r\nxt = np.random.randn(n_samples, 3).dot(P) + mu_t"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plotting the distributions\r\n--------------------------\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "fig = pl.figure()\r\nax1 = fig.add_subplot(121)\r\nax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')\r\nax2 = fig.add_subplot(122, projection='3d')\r\nax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')\r\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute distance kernels, normalize them and then display\r\n---------------------------------------------------------\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "C1 = sp.spatial.distance.cdist(xs, xs)\r\nC2 = sp.spatial.distance.cdist(xt, xt)\r\n\r\nC1 /= C1.max()\r\nC2 /= C2.max()\r\n\r\npl.figure()\r\npl.subplot(121)\r\npl.imshow(C1)\r\npl.subplot(122)\r\npl.imshow(C2)\r\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Compute Gromov-Wasserstein plans and distance\r\n---------------------------------------------\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+        "p = ot.unif(n_samples)\r\nq = ot.unif(n_samples)\r\n\r\ngw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)\r\ngw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)\r\n\r\nprint('Gromov-Wasserstein distance between the distributions: ' + str(gw_dist))\r\n\r\npl.figure()\r\npl.imshow(gw, cmap='jet')\r\npl.colorbar()\r\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_gromov.py b/docs/source/auto_examples/plot_gromov.py
new file mode 100644
index 0000000..d3f724c
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import numpy as np
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D # noqa
+import ot
+
+
+##############################################################################
+# Sample two Gaussian distributions (2D and 3D)
+# ---------------------------------------------
+#
+# The Gromov-Wasserstein distance makes it possible to compute distances between
+# samples that do not live in the same metric space. For demonstration purposes,
+# we sample two Gaussian distributions in 2- and 3-dimensional spaces.
+
+
+n_samples = 30 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+
+##############################################################################
+# Plotting the distributions
+# --------------------------
+
+
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+##############################################################################
+# Compute distance kernels, normalize them and then display
+# ---------------------------------------------------------
+
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+C1 /= C1.max()
+C2 /= C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+##############################################################################
+# Compute Gromov-Wasserstein plans and distance
+# ---------------------------------------------
+
+
+p = ot.unif(n_samples)
+q = ot.unif(n_samples)
+
+gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+print('Gromov-Wasserstein distance between the distributions: ' + str(gw_dist))
+
+pl.figure()
+pl.imshow(gw, cmap='jet')
+pl.colorbar()
+pl.show()
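The value returned by `ot.gromov_wasserstein2` is (up to the entropic term) the GW objective for the computed coupling `T`, i.e. the sum over all index pairs of `(C1[i, k] - C2[j, l])**2 * T[i, j] * T[k, l]`. Evaluating that objective for a given coupling can be done without forming a 4D tensor; a hedged NumPy-only sketch of the square-loss expansion:

```python
import numpy as np

def gw_objective(C1, C2, T):
    # sum_{i,j,k,l} (C1[i,k] - C2[j,l])**2 * T[i,j] * T[k,l],
    # expanded so no 4-dimensional tensor is ever formed.
    p = T.sum(axis=1)                    # first marginal
    q = T.sum(axis=0)                    # second marginal
    term1 = p @ (C1 ** 2) @ p            # sum C1[i,k]^2 p_i p_k
    term2 = q @ (C2 ** 2) @ q            # sum C2[j,l]^2 q_j q_l
    cross = np.sum((C1 @ T @ C2.T) * T)  # sum C1[i,k] C2[j,l] T[i,j] T[k,l]
    return term1 + term2 - 2 * cross

# Sanity check: identical spaces coupled by the scaled identity give 0.
rng = np.random.RandomState(0)
X = rng.randn(5, 2)
C = np.sqrt(((X[:, None] - X[None]) ** 2).sum(-1))
T_id = np.eye(5) / 5
print(gw_objective(C, C, T_id))  # -> 0.0 (up to rounding)
```

Scaling one of the structure matrices makes the spaces non-isometric, so the objective becomes strictly positive.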
diff --git a/docs/source/auto_examples/plot_gromov.rst b/docs/source/auto_examples/plot_gromov.rst
new file mode 100644
index 0000000..65cf4e4
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov.rst
@@ -0,0 +1,180 @@
+
+
+.. _sphx_glr_auto_examples_plot_gromov.py:
+
+
+==========================
+Gromov-Wasserstein example
+==========================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+
+
+
+.. code-block:: python
+
+
+ # Author: Erwan Vautier <erwan.vautier@gmail.com>
+ # Nicolas Courty <ncourty@irisa.fr>
+ #
+ # License: MIT License
+
+ import scipy as sp
+ import numpy as np
+ import matplotlib.pylab as pl
+ from mpl_toolkits.mplot3d import Axes3D # noqa
+ import ot
+
+
+
+
+
+
+
+
+Sample two Gaussian distributions (2D and 3D)
+---------------------------------------------
+
+The Gromov-Wasserstein distance makes it possible to compute distances between
+samples that do not live in the same metric space. For demonstration purposes,
+we sample two Gaussian distributions in 2- and 3-dimensional spaces.
+
+
+
+.. code-block:: python
+
+
+
+ n_samples = 30 # nb samples
+
+ mu_s = np.array([0, 0])
+ cov_s = np.array([[1, 0], [0, 1]])
+
+ mu_t = np.array([4, 4, 4])
+ cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+ xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+ P = sp.linalg.sqrtm(cov_t)
+ xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+
+
+
+
+
+
+
+Plotting the distributions
+--------------------------
+
+
+
+.. code-block:: python
+
+
+
+ fig = pl.figure()
+ ax1 = fig.add_subplot(121)
+ ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ ax2 = fig.add_subplot(122, projection='3d')
+ ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+ pl.show()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_001.png
+ :align: center
+
+
+
+
+Compute distance kernels, normalize them and then display
+---------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+
+ C1 = sp.spatial.distance.cdist(xs, xs)
+ C2 = sp.spatial.distance.cdist(xt, xt)
+
+ C1 /= C1.max()
+ C2 /= C2.max()
+
+ pl.figure()
+ pl.subplot(121)
+ pl.imshow(C1)
+ pl.subplot(122)
+ pl.imshow(C2)
+ pl.show()
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_002.png
+ :align: center
+
+
+
+
+Compute Gromov-Wasserstein plans and distance
+---------------------------------------------
+
+
+
+.. code-block:: python
+
+
+
+ p = ot.unif(n_samples)
+ q = ot.unif(n_samples)
+
+ gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+ gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+    print('Gromov-Wasserstein distance between the distributions: ' + str(gw_dist))
+
+ pl.figure()
+ pl.imshow(gw, cmap='jet')
+ pl.colorbar()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_003.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+    Gromov-Wasserstein distance between the distributions: 0.225058076974
+
+
+**Total running time of the script:** ( 0 minutes 4.070 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_gromov.py <plot_gromov.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_gromov.ipynb <plot_gromov.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.ipynb b/docs/source/auto_examples/plot_gromov_barycenter.ipynb
new file mode 100644
index 0000000..d38dfbb
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.ipynb
@@ -0,0 +1,126 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "\n# Gromov-Wasserstein Barycenter example\n\n\nThis example is designed to show how to use the Gromov-Wasserstein distance\ncomputation in POT.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Author: Erwan Vautier <erwan.vautier@gmail.com>\r\n# Nicolas Courty <ncourty@irisa.fr>\r\n#\r\n# License: MIT License\r\n\r\n\r\nimport numpy as np\r\nimport scipy as sp\r\n\r\nimport scipy.ndimage as spi\r\nimport matplotlib.pylab as pl\r\nfrom sklearn import manifold\r\nfrom sklearn.decomposition import PCA\r\n\r\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "Smacof MDS\r\n----------\r\n\r\nThis function finds an embedding of the points given the dissimilarity matrix\r\nproduced by the algorithm\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "def smacof_mds(C, dim, max_iter=3000, eps=1e-9):\r\n \"\"\"\r\n Returns an interpolated point cloud following the dissimilarity matrix C\r\n using SMACOF multidimensional scaling (MDS) in specific dimensionned\r\n target space\r\n\r\n Parameters\r\n ----------\r\n C : ndarray, shape (ns, ns)\r\n dissimilarity matrix\r\n dim : int\r\n dimension of the targeted space\r\n max_iter : int\r\n Maximum number of iterations of the SMACOF algorithm for a single run\r\n eps : float\r\n relative tolerance w.r.t stress to declare converge\r\n\r\n Returns\r\n -------\r\n npos : ndarray, shape (R, dim)\r\n Embedded coordinates of the interpolated point cloud (defined with\r\n one isometry)\r\n \"\"\"\r\n\r\n rng = np.random.RandomState(seed=3)\r\n\r\n mds = manifold.MDS(\r\n dim,\r\n max_iter=max_iter,\r\n eps=1e-9,\r\n dissimilarity='precomputed',\r\n n_init=1)\r\n pos = mds.fit(C).embedding_\r\n\r\n nmds = manifold.MDS(\r\n 2,\r\n max_iter=max_iter,\r\n eps=1e-9,\r\n dissimilarity=\"precomputed\",\r\n random_state=rng,\r\n n_init=1)\r\n npos = nmds.fit_transform(C, init=pos)\r\n\r\n return npos"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "Data preparation\r\n----------------\r\n\r\nThe four distributions are constructed from four simple images\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "def im2mat(I):\r\n \"\"\"Converts and image to matrix (one pixel per line)\"\"\"\r\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\r\n\r\n\r\nsquare = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256\r\ncross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256\r\ntriangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256\r\nstar = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256\r\n\r\nshapes = [square, cross, triangle, star]\r\n\r\nS = 4\r\nxs = [[] for i in range(S)]\r\n\r\n\r\nfor nb in range(4):\r\n for i in range(8):\r\n for j in range(8):\r\n if shapes[nb][i, j] < 0.95:\r\n xs[nb].append([j, 8 - i])\r\n\r\nxs = np.array([np.array(xs[0]), np.array(xs[1]),\r\n np.array(xs[2]), np.array(xs[3])])"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Barycenter computation\r\n----------------------\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "ns = [len(xs[s]) for s in range(S)]\r\nn_samples = 30\r\n\r\n\"\"\"Compute all distances matrices for the four shapes\"\"\"\r\nCs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]\r\nCs = [cs / cs.max() for cs in Cs]\r\n\r\nps = [ot.unif(ns[s]) for s in range(S)]\r\np = ot.unif(n_samples)\r\n\r\n\r\nlambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]\r\n\r\nCt01 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct01[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[1]],\r\n [ps[0], ps[1]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)\r\n\r\nCt02 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct02[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[0], Cs[2]],\r\n [ps[0], ps[2]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)\r\n\r\nCt13 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct13[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[1], Cs[3]],\r\n [ps[1], ps[3]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)\r\n\r\nCt23 = [0 for i in range(2)]\r\nfor i in range(2):\r\n Ct23[i] = ot.gromov.gromov_barycenters(n_samples, [Cs[2], Cs[3]],\r\n [ps[2], ps[3]\r\n ], p, lambdast[i], 'square_loss', 5e-4,\r\n max_iter=100, tol=1e-3)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "Visualization\r\n-------------\r\n\r\nPCA is used to keep the rotations of the embeddings consistent\r\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "clf = PCA(n_components=2)\r\nnpos = [0, 0, 0, 0]\r\nnpos = [smacof_mds(Cs[s], 2) for s in range(S)]\r\n\r\nnpost01 = [0, 0]\r\nnpost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]\r\nnpost01 = [clf.fit_transform(npost01[s]) for s in range(2)]\r\n\r\nnpost02 = [0, 0]\r\nnpost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]\r\nnpost02 = [clf.fit_transform(npost02[s]) for s in range(2)]\r\n\r\nnpost13 = [0, 0]\r\nnpost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]\r\nnpost13 = [clf.fit_transform(npost13[s]) for s in range(2)]\r\n\r\nnpost23 = [0, 0]\r\nnpost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]\r\nnpost23 = [clf.fit_transform(npost23[s]) for s in range(2)]\r\n\r\n\r\nfig = pl.figure(figsize=(10, 10))\r\n\r\nax1 = pl.subplot2grid((4, 4), (0, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')\r\n\r\nax2 = pl.subplot2grid((4, 4), (0, 1))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')\r\n\r\nax3 = pl.subplot2grid((4, 4), (0, 2))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')\r\n\r\nax4 = pl.subplot2grid((4, 4), (0, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')\r\n\r\nax5 = pl.subplot2grid((4, 4), (1, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')\r\n\r\nax6 = pl.subplot2grid((4, 4), (1, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')\r\n\r\nax7 = pl.subplot2grid((4, 4), (2, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')\r\n\r\nax8 = pl.subplot2grid((4, 4), (2, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')\r\n\r\nax9 = pl.subplot2grid((4, 4), (3, 0))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax9.scatter(npos[2][:, 0], 
npos[2][:, 1], color='r')\r\n\r\nax10 = pl.subplot2grid((4, 4), (3, 1))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')\r\n\r\nax11 = pl.subplot2grid((4, 4), (3, 2))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')\r\n\r\nax12 = pl.subplot2grid((4, 4), (3, 3))\r\npl.xlim((-1, 1))\r\npl.ylim((-1, 1))\r\nax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.py b/docs/source/auto_examples/plot_gromov_barycenter.py
new file mode 100644
index 0000000..180b0cf
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+
+import scipy.ndimage as spi
+import matplotlib.pylab as pl
+from sklearn import manifold
+from sklearn.decomposition import PCA
+
+import ot
+
+##############################################################################
+# Smacof MDS
+# ----------
+#
+# This function finds an embedding of the points from a dissimilarity
+# matrix, such as the one returned by the algorithm
+
+
+def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+    Returns an interpolated point cloud following the dissimilarity matrix C,
+    using SMACOF multidimensional scaling (MDS) in a target space of the
+    given dimension
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+        relative tolerance w.r.t. stress to declare convergence
+
+ Returns
+ -------
+ npos : ndarray, shape (R, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ 2,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity="precomputed",
+ random_state=rng,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
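For readers who want to try `smacof_mds` in isolation: it is essentially sklearn's `manifold.MDS` with a precomputed dissimilarity. A minimal sketch on a toy metric (the four collinear points are illustrative only):

```python
import numpy as np
from sklearn import manifold

# toy dissimilarity matrix: 4 points on a line
x = np.array([[0.], [1.], [2.], [3.]])
C = np.abs(x - x.T)  # pairwise distances, symmetric with zero diagonal

mds = manifold.MDS(2, dissimilarity='precomputed', n_init=1, random_state=0)
pos = mds.fit_transform(C)
print(pos.shape)  # (4, 2): one 2D coordinate per input point
```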
+
+##############################################################################
+# Data preparation
+# ----------------
+#
+# The four distributions are constructed from 4 simple images
+
+
+def im2mat(I):
+    """Converts an image to a matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+shapes = [square, cross, triangle, star]
+
+S = 4
+xs = [[] for i in range(S)]
+
+
+for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
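The triple loop above can also be written in vectorized form; a sketch that checks both formulations agree on a random 8x8 stand-in image (the random image is an assumption, not one of the actual shapes):

```python
import numpy as np

rng = np.random.RandomState(0)
img = rng.rand(8, 8)  # stand-in for one of the shape images

# loop version, as above
pts_loop = []
for i in range(8):
    for j in range(8):
        if img[i, j] < 0.95:
            pts_loop.append([j, 8 - i])
pts_loop = np.array(pts_loop)

# vectorized version: dark pixels become 2D points, vertical axis flipped
ii, jj = np.nonzero(img < 0.95)
pts_vec = np.stack([jj, 8 - ii], axis=1)

# np.nonzero scans in row-major order, matching the loop order exactly
assert np.array_equal(pts_loop, pts_vec)
```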
+
+##############################################################################
+# Barycenter computation
+# ----------------------
+
+
+ns = [len(xs[s]) for s in range(S)]
+n_samples = 30
+
+"""Compute all distances matrices for the four shapes"""
+Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+Cs = [cs / cs.max() for cs in Cs]
+
+ps = [ot.unif(ns[s]) for s in range(S)]
+p = ot.unif(n_samples)
+
+
+lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+Ct01 = [0 for i in range(2)]
+for i in range(2):
+    Ct01[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[0], Cs[1]], [ps[0], ps[1]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+Ct02 = [0 for i in range(2)]
+for i in range(2):
+    Ct02[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[0], Cs[2]], [ps[0], ps[2]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+Ct13 = [0 for i in range(2)]
+for i in range(2):
+    Ct13[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[1], Cs[3]], [ps[1], ps[3]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+Ct23 = [0 for i in range(2)]
+for i in range(2):
+    Ct23[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[2], Cs[3]], [ps[2], ps[3]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, tol=1e-3)
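Each entry of `lambdast` is a pair of convex-combination weights, so each barycenter interpolates between its two input shapes; a quick sanity check:

```python
# interpolation weights between two shapes, as used above
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
for w in lambdast:
    assert all(wi >= 0 for wi in w)      # nonnegative weights
    assert abs(sum(w) - 1.0) < 1e-12     # summing to one: a convex combination
```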
+
+
+##############################################################################
+# Visualization
+# -------------
+#
+# PCA is used to align the embedded point clouds, so that the rotations are
+# consistent across the subplots
+
+
+clf = PCA(n_components=2)
+
+npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+fig = pl.figure(figsize=(10, 10))
+
+ax1 = pl.subplot2grid((4, 4), (0, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ax2 = pl.subplot2grid((4, 4), (0, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ax3 = pl.subplot2grid((4, 4), (0, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ax4 = pl.subplot2grid((4, 4), (0, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ax5 = pl.subplot2grid((4, 4), (1, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ax6 = pl.subplot2grid((4, 4), (1, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ax7 = pl.subplot2grid((4, 4), (2, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ax8 = pl.subplot2grid((4, 4), (2, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ax9 = pl.subplot2grid((4, 4), (3, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ax10 = pl.subplot2grid((4, 4), (3, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ax11 = pl.subplot2grid((4, 4), (3, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ax12 = pl.subplot2grid((4, 4), (3, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
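Why PCA gives this consistency: projecting each embedded cloud onto its own principal axes brings clouds that differ only by a rotation into the same pose, up to per-axis sign flips. A small sketch of that property (toy data, not the shapes above):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(0)
pts = rng.randn(30, 2)
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])

a = PCA(n_components=2).fit_transform(pts)        # original cloud
b = PCA(n_components=2).fit_transform(pts @ R.T)  # same cloud, rotated

# identical coordinates up to per-axis sign flips
assert np.allclose(np.abs(a), np.abs(b), atol=1e-8)
```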
diff --git a/docs/source/auto_examples/plot_gromov_barycenter.rst b/docs/source/auto_examples/plot_gromov_barycenter.rst
new file mode 100644
index 0000000..ca2d4e9
--- /dev/null
+++ b/docs/source/auto_examples/plot_gromov_barycenter.rst
@@ -0,0 +1,324 @@
+
+
+.. _sphx_glr_auto_examples_plot_gromov_barycenter.py:
+
+
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+
+
+
+.. code-block:: python
+
+
+ # Author: Erwan Vautier <erwan.vautier@gmail.com>
+ # Nicolas Courty <ncourty@irisa.fr>
+ #
+ # License: MIT License
+
+
+ import numpy as np
+ import scipy as sp
+
+ import scipy.ndimage as spi
+ import matplotlib.pylab as pl
+ from sklearn import manifold
+ from sklearn.decomposition import PCA
+
+ import ot
+
+
+
+
+
+
+
+Smacof MDS
+----------
+
+This function finds an embedding of the points from a dissimilarity
+matrix, such as the one returned by the algorithm
+
+
+
+.. code-block:: python
+
+
+
+ def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+        Returns an interpolated point cloud following the dissimilarity matrix C,
+        using SMACOF multidimensional scaling (MDS) in a target space of the
+        given dimension
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+            relative tolerance w.r.t. stress to declare convergence
+
+ Returns
+ -------
+ npos : ndarray, shape (R, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+ mds = manifold.MDS(
+ dim,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity='precomputed',
+ n_init=1)
+ pos = mds.fit(C).embedding_
+
+ nmds = manifold.MDS(
+ 2,
+ max_iter=max_iter,
+ eps=1e-9,
+ dissimilarity="precomputed",
+ random_state=rng,
+ n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+
+
+
+
+
+
+Data preparation
+----------------
+
+The four distributions are constructed from 4 simple images
+
+
+
+.. code-block:: python
+
+
+
+ def im2mat(I):
+        """Converts an image to a matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+ cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+ triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+ star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+ shapes = [square, cross, triangle, star]
+
+ S = 4
+ xs = [[] for i in range(S)]
+
+
+ for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+ xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+
+
+
+
+
+
+Barycenter computation
+----------------------
+
+
+
+.. code-block:: python
+
+
+
+ ns = [len(xs[s]) for s in range(S)]
+ n_samples = 30
+
+    # Compute all distance matrices for the four shapes
+ Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+ Cs = [cs / cs.max() for cs in Cs]
+
+ ps = [ot.unif(ns[s]) for s in range(S)]
+ p = ot.unif(n_samples)
+
+
+ lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+    Ct01 = [0 for i in range(2)]
+    for i in range(2):
+        Ct01[i] = ot.gromov.gromov_barycenters(
+            n_samples, [Cs[0], Cs[1]], [ps[0], ps[1]], p, lambdast[i],
+            'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+    Ct02 = [0 for i in range(2)]
+    for i in range(2):
+        Ct02[i] = ot.gromov.gromov_barycenters(
+            n_samples, [Cs[0], Cs[2]], [ps[0], ps[2]], p, lambdast[i],
+            'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+    Ct13 = [0 for i in range(2)]
+    for i in range(2):
+        Ct13[i] = ot.gromov.gromov_barycenters(
+            n_samples, [Cs[1], Cs[3]], [ps[1], ps[3]], p, lambdast[i],
+            'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+    Ct23 = [0 for i in range(2)]
+    for i in range(2):
+        Ct23[i] = ot.gromov.gromov_barycenters(
+            n_samples, [Cs[2], Cs[3]], [ps[2], ps[3]], p, lambdast[i],
+            'square_loss', 5e-4, max_iter=100, tol=1e-3)
+
+
+
+
+
+
+
+
+Visualization
+-------------
+
+PCA is used to align the embedded point clouds, so that the rotations are
+consistent across the subplots
+
+
+
+.. code-block:: python
+
+
+
+    clf = PCA(n_components=2)
+
+    npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+    npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+    npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+    npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+    npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+    npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+    npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+    npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+    npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+ fig = pl.figure(figsize=(10, 10))
+
+ ax1 = pl.subplot2grid((4, 4), (0, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ ax2 = pl.subplot2grid((4, 4), (0, 1))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ ax3 = pl.subplot2grid((4, 4), (0, 2))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ ax4 = pl.subplot2grid((4, 4), (0, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ ax5 = pl.subplot2grid((4, 4), (1, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ ax6 = pl.subplot2grid((4, 4), (1, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ ax7 = pl.subplot2grid((4, 4), (2, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ ax8 = pl.subplot2grid((4, 4), (2, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ ax9 = pl.subplot2grid((4, 4), (3, 0))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ ax10 = pl.subplot2grid((4, 4), (3, 1))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ ax11 = pl.subplot2grid((4, 4), (3, 2))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ ax12 = pl.subplot2grid((4, 4), (3, 3))
+ pl.xlim((-1, 1))
+ pl.ylim((-1, 1))
+ ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_gromov_barycenter_001.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 8 minutes 43.875 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_gromov_barycenter.py <plot_gromov_barycenter.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_gromov_barycenter.ipynb <plot_gromov_barycenter.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_optim_OTreg.ipynb b/docs/source/auto_examples/plot_optim_OTreg.ipynb
index 5ded922..333331b 100644
--- a/docs/source/auto_examples/plot_optim_OTreg.ipynb
+++ b/docs/source/auto_examples/plot_optim_OTreg.ipynb
@@ -15,7 +15,7 @@
},
{
"source": [
- "\n# Regularized OT with generic solver\n\n\n\n\n"
+ "\n# Regularized OT with generic solver\n\n\nIllustrates the use of the generic solver for regularized OT with\nuser-designed regularization term. It uses Conditional gradient as in [6] and\ngeneralized Conditional Gradient as proposed in [5][7].\n\n\n[5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, Optimal Transport for\nDomain Adaptation, in IEEE Transactions on Pattern Analysis and Machine\nIntelligence , vol.PP, no.99, pp.1-1.\n\n[6] Ferradans, S., Papadakis, N., Peyr\u00e9, G., & Aujol, J. F. (2014).\nRegularized discrete optimal transport. SIAM Journal on Imaging Sciences,\n7(3), 1853-1882.\n\n[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized\nconditional gradient: analysis of convergence and applications.\narXiv preprint arXiv:1510.06567.\n\n\n\n\n"
],
"cell_type": "markdown",
"metadata": {}
@@ -24,7 +24,97 @@
"execution_count": null,
"cell_type": "code",
"source": [
- "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\n\n\n\n#%% parameters\n\nn=100 # nb bins\n\n# bin positions\nx=np.arange(n,dtype=np.float64)\n\n# Gaussian distributions\na=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std\nb=ot.datasets.get_1D_gauss(n,m=60,s=10)\n\n# loss matrix\nM=ot.dist(x.reshape((n,1)),x.reshape((n,1)))\nM/=M.max()\n\n#%% EMD\n\nG0=ot.emd(a,b,M)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,G0,'OT matrix G0')\n\n#%% Example with Frobenius norm regularization\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg=1e-1\n\nGl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(3)\not.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')\n\n#%% Example with entropic regularization\n\ndef f(G): return np.sum(G*np.log(G))\ndef df(G): return np.log(G)+1\n\nreg=1e-3\n\nGe=ot.optim.cg(a,b,M,reg,f,df,verbose=True)\n\npl.figure(4)\not.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')\n\n#%% Example with Frobenius norm + entropic regularization with gcg\n\ndef f(G): return 0.5*np.sum(G**2)\ndef df(G): return G\n\nreg1=1e-3\nreg2=1e-1\n\nGel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)\n\npl.figure(5)\not.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')\npl.show()"
+ "import numpy as np\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std\nb = ot.datasets.get_1D_gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Solve EMD\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Solve EMD with Frobenius norm regularization\n--------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Example with Frobenius norm regularization\n\n\ndef f(G):\n return 0.5 * np.sum(G**2)\n\n\ndef df(G):\n return G\n\n\nreg = 1e-1\n\nGl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)\n\npl.figure(3)\not.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Solve EMD with entropic regularization\n--------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Example with entropic regularization\n\n\ndef f(G):\n return np.sum(G * np.log(G))\n\n\ndef df(G):\n return np.log(G) + 1.\n\n\nreg = 1e-3\n\nGe = ot.optim.cg(a, b, M, reg, f, df, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Solve EMD with Frobenius norm + entropic regularization\n-------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "#%% Example with Frobenius norm + entropic regularization with gcg\n\n\ndef f(G):\n return 0.5 * np.sum(G**2)\n\n\ndef df(G):\n return G\n\n\nreg1 = 1e-3\nreg2 = 1e-1\n\nGel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)\n\npl.figure(5, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')\npl.show()"
],
"outputs": [],
"metadata": {
diff --git a/docs/source/auto_examples/plot_optim_OTreg.py b/docs/source/auto_examples/plot_optim_OTreg.py
index 8abb426..e1a737e 100644
--- a/docs/source/auto_examples/plot_optim_OTreg.py
+++ b/docs/source/auto_examples/plot_optim_OTreg.py
@@ -4,6 +4,24 @@
Regularized OT with generic solver
==================================
+Illustrates the use of the generic solver for regularized OT with a
+user-designed regularization term. It uses the conditional gradient algorithm,
+as in [6], and the generalized conditional gradient, as proposed in [5] and [7].
+
+
+[5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, Optimal Transport for
+Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
+Intelligence, vol. PP, no. 99, pp. 1-1.
+
+[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
+7(3), 1853-1882.
+
+[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
+conditional gradient: analysis of convergence and applications.
+arXiv preprint arXiv:1510.06567.
+
+
"""
@@ -12,63 +30,100 @@ import matplotlib.pylab as pl
import ot
+##############################################################################
+# Generate data
+# -------------
#%% parameters
-n=100 # nb bins
+n = 100 # nb bins
# bin positions
-x=np.arange(n,dtype=np.float64)
+x = np.arange(n, dtype=np.float64)
# Gaussian distributions
-a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
-b=ot.datasets.get_1D_gauss(n,m=60,s=10)
+a = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+b = ot.datasets.get_1D_gauss(n, m=60, s=10)
# loss matrix
-M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
-M/=M.max()
+M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+M /= M.max()
+
+##############################################################################
+# Solve EMD
+# ---------
#%% EMD
-G0=ot.emd(a,b,M)
+G0 = ot.emd(a, b, M)
-pl.figure(3)
-ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
+pl.figure(3, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
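The plan `G0` returned by `ot.emd` is a coupling of `a` and `b`: its row sums recover `a` and its column sums recover `b`. A tiny hand-made illustration of that marginal constraint (the numbers are illustrative, not computed by POT):

```python
import numpy as np

a = np.array([0.5, 0.5])        # source histogram
b = np.array([0.25, 0.75])      # target histogram
G = np.array([[0.25, 0.25],     # a valid transport plan between a and b
              [0.0,  0.5]])

assert np.allclose(G.sum(axis=1), a)  # rows marginalize to a
assert np.allclose(G.sum(axis=0), b)  # columns marginalize to b
```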
+
+##############################################################################
+# Solve EMD with Frobenius norm regularization
+# --------------------------------------------
#%% Example with Frobenius norm regularization
-def f(G): return 0.5*np.sum(G**2)
-def df(G): return G
-reg=1e-1
+def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+def df(G):
+ return G
+
-Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
+reg = 1e-1
+
+Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
pl.figure(3)
-ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')
+ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
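The generic solver only requires that `df` really is the gradient of `f`; a finite-difference check of the Frobenius pair used above:

```python
import numpy as np

def f(G):
    return 0.5 * np.sum(G**2)

def df(G):
    return G

G = np.random.RandomState(0).rand(3, 3)
eps = 1e-6
E = np.zeros_like(G)
E[0, 0] = eps

# central finite difference of f along the (0, 0) entry
num = (f(G + E) - f(G - E)) / (2 * eps)
assert abs(num - df(G)[0, 0]) < 1e-8
```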
+
+##############################################################################
+# Solve EMD with entropic regularization
+# --------------------------------------
#%% Example with entropic regularization
-def f(G): return np.sum(G*np.log(G))
-def df(G): return np.log(G)+1
-reg=1e-3
+def f(G):
+ return np.sum(G * np.log(G))
+
+
+def df(G):
+ return np.log(G) + 1.
+
+
+reg = 1e-3
-Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
+Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
-pl.figure(4)
-ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')
+pl.figure(4, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
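The same check works for the entropic pair; note that `f(G) = sum(G * log(G))` is only finite for strictly positive `G`, so the check is done at an interior point:

```python
import numpy as np

def f(G):
    return np.sum(G * np.log(G))

def df(G):
    return np.log(G) + 1.

G = np.full((2, 2), 0.25)  # strictly positive, e.g. a uniform coupling
eps = 1e-7
E = np.zeros_like(G)
E[0, 1] = eps

# central finite difference of f along the (0, 1) entry
num = (f(G + E) - f(G - E)) / (2 * eps)
assert abs(num - df(G)[0, 1]) < 1e-5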
+
+##############################################################################
+# Solve EMD with Frobenius norm + entropic regularization
+# -------------------------------------------------------
#%% Example with Frobenius norm + entropic regularization with gcg
-def f(G): return 0.5*np.sum(G**2)
-def df(G): return G
-reg1=1e-3
-reg2=1e-1
+def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+def df(G):
+ return G
+
+
+reg1 = 1e-3
+reg2 = 1e-1
-Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)
+Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
-pl.figure(5)
-ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')
-pl.show() \ No newline at end of file
+pl.figure(5, figsize=(5, 5))
+ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
+pl.show()
diff --git a/docs/source/auto_examples/plot_optim_OTreg.rst b/docs/source/auto_examples/plot_optim_OTreg.rst
index 70cd26c..480149a 100644
--- a/docs/source/auto_examples/plot_optim_OTreg.rst
+++ b/docs/source/auto_examples/plot_optim_OTreg.rst
@@ -7,28 +7,126 @@
Regularized OT with generic solver
==================================
+Illustrates the use of the generic solver for regularized OT with a
+user-designed regularization term. It uses the conditional gradient algorithm,
+as in [6], and the generalized conditional gradient, as proposed in [5] and [7].
+
+[5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, Optimal Transport for
+Domain Adaptation, in IEEE Transactions on Pattern Analysis and Machine
+Intelligence, vol. PP, no. 99, pp. 1-1.
+
+[6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport. SIAM Journal on Imaging Sciences,
+7(3), 1853-1882.
+
+[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
+conditional gradient: analysis of convergence and applications.
+arXiv preprint arXiv:1510.06567.
-.. rst-class:: sphx-glr-horizontal
- *
- .. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
- :scale: 47
- *
- .. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
- :scale: 47
+.. code-block:: python
+
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ #%% parameters
+
+ n = 100 # nb bins
+
+ # bin positions
+ x = np.arange(n, dtype=np.float64)
+
+ # Gaussian distributions
+ a = ot.datasets.get_1D_gauss(n, m=20, s=5) # m= mean, s= std
+ b = ot.datasets.get_1D_gauss(n, m=60, s=10)
+
+ # loss matrix
+ M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
+ M /= M.max()
+
+
+
+
+
+
+
+Solve EMD
+---------
+
+
+
+.. code-block:: python
+
+
+ #%% EMD
+
+ G0 = ot.emd(a, b, M)
+
+ pl.figure(3, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_003.png
+ :align: center
+
+
+
+
+Solve EMD with Frobenius norm regularization
+--------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Example with Frobenius norm regularization
+
+
+ def f(G):
+ return 0.5 * np.sum(G**2)
+
+
+ def df(G):
+ return G
+
+
+ reg = 1e-1
+
+ Gl2 = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+ pl.figure(3)
+ ot.plot.plot1D_mat(a, b, Gl2, 'OT matrix Frob. reg')
+
+
- *
- .. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_005.png
- :scale: 47
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_004.png
+ :align: center
.. rst-class:: sphx-glr-script-out
@@ -258,312 +356,287 @@ Regularized OT with generic solver
It. |Loss |Delta loss
--------------------------------
200|1.663543e-01|-8.737134e-08
+
+
+Solve EMD with entropic regularization
+--------------------------------------
+
+
+
+.. code-block:: python
+
+
+ #%% Example with entropic regularization
+
+
+ def f(G):
+ return np.sum(G * np.log(G))
+
+
+ def df(G):
+ return np.log(G) + 1.
+
+
+ reg = 1e-3
+
+ Ge = ot.optim.cg(a, b, M, reg, f, df, verbose=True)
+
+ pl.figure(4, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Ge, 'OT matrix Entrop. reg')
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_006.png
+ :align: center
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
It. |Loss |Delta loss
--------------------------------
0|1.692289e-01|0.000000e+00
1|1.617643e-01|-4.614437e-02
- 2|1.612546e-01|-3.161037e-03
- 3|1.611040e-01|-9.349544e-04
- 4|1.610346e-01|-4.310179e-04
- 5|1.610072e-01|-1.701719e-04
- 6|1.609947e-01|-7.759814e-05
- 7|1.609934e-01|-7.941439e-06
- 8|1.609841e-01|-5.797180e-05
- 9|1.609838e-01|-1.559407e-06
- 10|1.609685e-01|-9.530282e-05
- 11|1.609666e-01|-1.142129e-05
- 12|1.609541e-01|-7.799970e-05
- 13|1.609496e-01|-2.780416e-05
- 14|1.609385e-01|-6.887105e-05
- 15|1.609334e-01|-3.174241e-05
- 16|1.609231e-01|-6.420777e-05
- 17|1.609115e-01|-7.189949e-05
- 18|1.608815e-01|-1.865331e-04
- 19|1.608799e-01|-1.013039e-05
+ 2|1.612639e-01|-3.102965e-03
+ 3|1.611291e-01|-8.371098e-04
+ 4|1.610468e-01|-5.110558e-04
+ 5|1.610198e-01|-1.672927e-04
+ 6|1.610130e-01|-4.232417e-05
+ 7|1.610090e-01|-2.513455e-05
+ 8|1.610002e-01|-5.443507e-05
+ 9|1.609996e-01|-3.657071e-06
+ 10|1.609948e-01|-2.998735e-05
+ 11|1.609695e-01|-1.569217e-04
+ 12|1.609533e-01|-1.010779e-04
+ 13|1.609520e-01|-8.043897e-06
+ 14|1.609465e-01|-3.415246e-05
+ 15|1.609386e-01|-4.898605e-05
+ 16|1.609324e-01|-3.837052e-05
+ 17|1.609298e-01|-1.617826e-05
+ 18|1.609184e-01|-7.080015e-05
+ 19|1.609083e-01|-6.273206e-05
It. |Loss |Delta loss
--------------------------------
- 20|1.608695e-01|-6.468606e-05
- 21|1.608686e-01|-5.738419e-06
- 22|1.608661e-01|-1.495923e-05
- 23|1.608657e-01|-2.784611e-06
- 24|1.608633e-01|-1.512408e-05
- 25|1.608624e-01|-5.397916e-06
- 26|1.608617e-01|-4.115218e-06
- 27|1.608561e-01|-3.503396e-05
- 28|1.608479e-01|-5.098773e-05
- 29|1.608452e-01|-1.659203e-05
- 30|1.608399e-01|-3.298319e-05
- 31|1.608330e-01|-4.302183e-05
- 32|1.608310e-01|-1.273465e-05
- 33|1.608280e-01|-1.827713e-05
- 34|1.608231e-01|-3.039842e-05
- 35|1.608212e-01|-1.229256e-05
- 36|1.608200e-01|-6.900556e-06
- 37|1.608159e-01|-2.554039e-05
- 38|1.608103e-01|-3.521137e-05
- 39|1.608058e-01|-2.795180e-05
+ 20|1.608988e-01|-5.940805e-05
+ 21|1.608853e-01|-8.380030e-05
+ 22|1.608844e-01|-5.185045e-06
+ 23|1.608824e-01|-1.279113e-05
+ 24|1.608819e-01|-3.156821e-06
+ 25|1.608783e-01|-2.205746e-05
+ 26|1.608764e-01|-1.189894e-05
+ 27|1.608755e-01|-5.474607e-06
+ 28|1.608737e-01|-1.144227e-05
+ 29|1.608676e-01|-3.775335e-05
+ 30|1.608638e-01|-2.348020e-05
+ 31|1.608627e-01|-6.863136e-06
+ 32|1.608529e-01|-6.110230e-05
+ 33|1.608487e-01|-2.641106e-05
+ 34|1.608409e-01|-4.823638e-05
+ 35|1.608373e-01|-2.256641e-05
+ 36|1.608338e-01|-2.132444e-05
+ 37|1.608310e-01|-1.786649e-05
+ 38|1.608260e-01|-3.103848e-05
+ 39|1.608206e-01|-3.321265e-05
It. |Loss |Delta loss
--------------------------------
- 40|1.608040e-01|-1.119118e-05
- 41|1.608027e-01|-8.193369e-06
- 42|1.607994e-01|-2.026719e-05
- 43|1.607985e-01|-5.819902e-06
- 44|1.607978e-01|-4.048170e-06
- 45|1.607978e-01|-3.007470e-07
- 46|1.607950e-01|-1.705375e-05
- 47|1.607927e-01|-1.430186e-05
- 48|1.607925e-01|-1.166526e-06
- 49|1.607911e-01|-9.069406e-06
- 50|1.607910e-01|-3.804209e-07
- 51|1.607910e-01|-5.942399e-08
- 52|1.607910e-01|-2.321380e-07
- 53|1.607907e-01|-1.877655e-06
- 54|1.607906e-01|-2.940224e-07
- 55|1.607877e-01|-1.814208e-05
- 56|1.607841e-01|-2.236496e-05
- 57|1.607810e-01|-1.951355e-05
- 58|1.607804e-01|-3.578228e-06
- 59|1.607789e-01|-9.442277e-06
+ 40|1.608201e-01|-3.054747e-06
+ 41|1.608195e-01|-4.198335e-06
+ 42|1.608193e-01|-8.458736e-07
+ 43|1.608159e-01|-2.153759e-05
+ 44|1.608115e-01|-2.738314e-05
+ 45|1.608108e-01|-3.960032e-06
+ 46|1.608081e-01|-1.675447e-05
+ 47|1.608072e-01|-5.976340e-06
+ 48|1.608046e-01|-1.604130e-05
+ 49|1.608020e-01|-1.617036e-05
+ 50|1.608014e-01|-3.957795e-06
+ 51|1.608011e-01|-1.292411e-06
+ 52|1.607998e-01|-8.431795e-06
+ 53|1.607964e-01|-2.127054e-05
+ 54|1.607947e-01|-1.021878e-05
+ 55|1.607947e-01|-3.560621e-07
+ 56|1.607900e-01|-2.929781e-05
+ 57|1.607890e-01|-5.740229e-06
+ 58|1.607858e-01|-2.039550e-05
+ 59|1.607836e-01|-1.319545e-05
It. |Loss |Delta loss
--------------------------------
- 60|1.607779e-01|-5.997371e-06
- 61|1.607754e-01|-1.564408e-05
- 62|1.607742e-01|-7.693285e-06
- 63|1.607727e-01|-9.030547e-06
- 64|1.607719e-01|-5.103894e-06
- 65|1.607693e-01|-1.605420e-05
- 66|1.607676e-01|-1.047837e-05
- 67|1.607675e-01|-6.026848e-07
- 68|1.607655e-01|-1.240216e-05
- 69|1.607632e-01|-1.434674e-05
- 70|1.607618e-01|-8.829808e-06
- 71|1.607606e-01|-7.581824e-06
- 72|1.607590e-01|-1.009457e-05
- 73|1.607586e-01|-2.222963e-06
- 74|1.607577e-01|-5.564775e-06
- 75|1.607574e-01|-1.932763e-06
- 76|1.607573e-01|-8.148685e-07
- 77|1.607554e-01|-1.187660e-05
- 78|1.607546e-01|-4.557651e-06
- 79|1.607537e-01|-5.911902e-06
+ 60|1.607826e-01|-6.378947e-06
+ 61|1.607808e-01|-1.145102e-05
+ 62|1.607776e-01|-1.941743e-05
+ 63|1.607743e-01|-2.087422e-05
+ 64|1.607741e-01|-1.310249e-06
+ 65|1.607738e-01|-1.682752e-06
+ 66|1.607691e-01|-2.913936e-05
+ 67|1.607671e-01|-1.288855e-05
+ 68|1.607654e-01|-1.002448e-05
+ 69|1.607641e-01|-8.209492e-06
+ 70|1.607632e-01|-5.588467e-06
+ 71|1.607619e-01|-8.050388e-06
+ 72|1.607618e-01|-9.417493e-07
+ 73|1.607598e-01|-1.210509e-05
+ 74|1.607591e-01|-4.392914e-06
+ 75|1.607579e-01|-7.759587e-06
+ 76|1.607574e-01|-2.760280e-06
+ 77|1.607556e-01|-1.146469e-05
+ 78|1.607550e-01|-3.689456e-06
+ 79|1.607550e-01|-4.065631e-08
It. |Loss |Delta loss
--------------------------------
- 80|1.607529e-01|-4.710187e-06
- 81|1.607528e-01|-8.866080e-07
- 82|1.607522e-01|-3.620627e-06
- 83|1.607514e-01|-5.091281e-06
- 84|1.607498e-01|-9.932095e-06
- 85|1.607487e-01|-6.852804e-06
- 86|1.607478e-01|-5.373596e-06
- 87|1.607473e-01|-3.287295e-06
- 88|1.607470e-01|-1.666655e-06
- 89|1.607469e-01|-5.293790e-07
- 90|1.607466e-01|-2.051914e-06
- 91|1.607456e-01|-6.422797e-06
- 92|1.607456e-01|-1.110433e-07
- 93|1.607451e-01|-2.803849e-06
- 94|1.607451e-01|-2.608066e-07
- 95|1.607441e-01|-6.290352e-06
- 96|1.607429e-01|-7.298455e-06
- 97|1.607429e-01|-8.969905e-09
- 98|1.607427e-01|-7.923968e-07
- 99|1.607427e-01|-3.519286e-07
+ 80|1.607539e-01|-6.555681e-06
+ 81|1.607528e-01|-7.177470e-06
+ 82|1.607527e-01|-5.306068e-07
+ 83|1.607514e-01|-7.816045e-06
+ 84|1.607511e-01|-2.301970e-06
+ 85|1.607504e-01|-4.281072e-06
+ 86|1.607503e-01|-7.821886e-07
+ 87|1.607480e-01|-1.403013e-05
+ 88|1.607480e-01|-1.169298e-08
+ 89|1.607473e-01|-4.235982e-06
+ 90|1.607470e-01|-1.717105e-06
+ 91|1.607470e-01|-6.148402e-09
+ 92|1.607462e-01|-5.396481e-06
+ 93|1.607461e-01|-5.194954e-07
+ 94|1.607450e-01|-6.525707e-06
+ 95|1.607442e-01|-5.332060e-06
+ 96|1.607439e-01|-1.682093e-06
+ 97|1.607437e-01|-1.594796e-06
+ 98|1.607435e-01|-7.923812e-07
+ 99|1.607420e-01|-9.738552e-06
It. |Loss |Delta loss
--------------------------------
- 100|1.607426e-01|-3.563804e-07
- 101|1.607410e-01|-1.004042e-05
- 102|1.607410e-01|-2.124801e-07
- 103|1.607398e-01|-7.556935e-06
- 104|1.607398e-01|-7.606853e-08
- 105|1.607385e-01|-8.058684e-06
- 106|1.607383e-01|-7.393061e-07
- 107|1.607381e-01|-1.504958e-06
- 108|1.607377e-01|-2.508807e-06
- 109|1.607371e-01|-4.004631e-06
- 110|1.607365e-01|-3.580156e-06
- 111|1.607364e-01|-2.563573e-07
- 112|1.607354e-01|-6.390137e-06
- 113|1.607348e-01|-4.119553e-06
- 114|1.607339e-01|-5.299475e-06
- 115|1.607335e-01|-2.316767e-06
- 116|1.607330e-01|-3.444737e-06
- 117|1.607324e-01|-3.467980e-06
- 118|1.607320e-01|-2.374632e-06
- 119|1.607319e-01|-7.978255e-07
+ 100|1.607419e-01|-1.022448e-07
+ 101|1.607419e-01|-4.865999e-07
+ 102|1.607418e-01|-7.092012e-07
+ 103|1.607408e-01|-5.861815e-06
+ 104|1.607402e-01|-3.953266e-06
+ 105|1.607395e-01|-3.969572e-06
+ 106|1.607390e-01|-3.612075e-06
+ 107|1.607377e-01|-7.683735e-06
+ 108|1.607365e-01|-7.777599e-06
+ 109|1.607364e-01|-2.335096e-07
+ 110|1.607364e-01|-4.562036e-07
+ 111|1.607360e-01|-2.089538e-06
+ 112|1.607356e-01|-2.755355e-06
+ 113|1.607349e-01|-4.501960e-06
+ 114|1.607347e-01|-1.160544e-06
+ 115|1.607346e-01|-6.289450e-07
+ 116|1.607345e-01|-2.092146e-07
+ 117|1.607336e-01|-5.990866e-06
+ 118|1.607330e-01|-3.348498e-06
+ 119|1.607328e-01|-1.256222e-06
It. |Loss |Delta loss
--------------------------------
- 120|1.607312e-01|-4.221434e-06
- 121|1.607310e-01|-1.324597e-06
- 122|1.607304e-01|-3.650359e-06
- 123|1.607298e-01|-3.732712e-06
- 124|1.607295e-01|-1.994082e-06
- 125|1.607289e-01|-3.954139e-06
- 126|1.607286e-01|-1.532372e-06
- 127|1.607286e-01|-1.167223e-07
- 128|1.607283e-01|-2.157376e-06
- 129|1.607279e-01|-2.253077e-06
- 130|1.607274e-01|-3.301532e-06
- 131|1.607269e-01|-2.650754e-06
- 132|1.607264e-01|-3.595551e-06
- 133|1.607262e-01|-1.159425e-06
- 134|1.607258e-01|-2.512411e-06
- 135|1.607255e-01|-1.998792e-06
- 136|1.607251e-01|-2.486536e-06
- 137|1.607246e-01|-2.782996e-06
- 138|1.607246e-01|-2.922470e-07
- 139|1.607242e-01|-2.071131e-06
+ 120|1.607320e-01|-5.418353e-06
+ 121|1.607318e-01|-8.296189e-07
+ 122|1.607311e-01|-4.381608e-06
+ 123|1.607310e-01|-8.913901e-07
+ 124|1.607309e-01|-3.808821e-07
+ 125|1.607302e-01|-4.608994e-06
+ 126|1.607294e-01|-5.063777e-06
+ 127|1.607290e-01|-2.532835e-06
+ 128|1.607285e-01|-2.870049e-06
+ 129|1.607284e-01|-4.892812e-07
+ 130|1.607281e-01|-1.760452e-06
+ 131|1.607279e-01|-1.727139e-06
+ 132|1.607275e-01|-2.220706e-06
+ 133|1.607271e-01|-2.516930e-06
+ 134|1.607269e-01|-1.201434e-06
+ 135|1.607269e-01|-2.183459e-09
+ 136|1.607262e-01|-4.223011e-06
+ 137|1.607258e-01|-2.530202e-06
+ 138|1.607258e-01|-1.857260e-07
+ 139|1.607256e-01|-1.401957e-06
It. |Loss |Delta loss
--------------------------------
- 140|1.607237e-01|-3.154193e-06
- 141|1.607235e-01|-1.194962e-06
- 142|1.607232e-01|-2.035251e-06
- 143|1.607232e-01|-6.027855e-08
- 144|1.607229e-01|-1.555696e-06
- 145|1.607228e-01|-1.081740e-06
- 146|1.607225e-01|-1.881070e-06
- 147|1.607224e-01|-4.100096e-07
- 148|1.607223e-01|-7.785200e-07
- 149|1.607222e-01|-2.094072e-07
- 150|1.607220e-01|-1.440814e-06
- 151|1.607217e-01|-1.997794e-06
- 152|1.607214e-01|-2.011022e-06
- 153|1.607212e-01|-8.808854e-07
- 154|1.607211e-01|-7.245877e-07
- 155|1.607207e-01|-2.217159e-06
- 156|1.607201e-01|-3.817891e-06
- 157|1.607200e-01|-7.409600e-07
- 158|1.607198e-01|-1.497698e-06
- 159|1.607195e-01|-1.729666e-06
+ 140|1.607250e-01|-3.242751e-06
+ 141|1.607247e-01|-2.308071e-06
+ 142|1.607247e-01|-4.730700e-08
+ 143|1.607246e-01|-4.240229e-07
+ 144|1.607242e-01|-2.484810e-06
+ 145|1.607238e-01|-2.539206e-06
+ 146|1.607234e-01|-2.535574e-06
+ 147|1.607231e-01|-1.954802e-06
+ 148|1.607228e-01|-1.765447e-06
+ 149|1.607228e-01|-1.620007e-08
+ 150|1.607222e-01|-3.615783e-06
+ 151|1.607222e-01|-8.668516e-08
+ 152|1.607215e-01|-4.000673e-06
+ 153|1.607213e-01|-1.774103e-06
+ 154|1.607213e-01|-6.328834e-09
+ 155|1.607209e-01|-2.418783e-06
+ 156|1.607208e-01|-2.848492e-07
+ 157|1.607207e-01|-8.836043e-07
+ 158|1.607205e-01|-1.192836e-06
+ 159|1.607202e-01|-1.638022e-06
It. |Loss |Delta loss
--------------------------------
- 160|1.607195e-01|-2.115187e-07
- 161|1.607192e-01|-1.643727e-06
- 162|1.607192e-01|-1.712969e-07
- 163|1.607189e-01|-1.805877e-06
- 164|1.607189e-01|-1.209827e-07
- 165|1.607185e-01|-2.060002e-06
- 166|1.607182e-01|-1.961341e-06
- 167|1.607181e-01|-1.020366e-06
- 168|1.607179e-01|-9.760982e-07
- 169|1.607178e-01|-7.219236e-07
- 170|1.607175e-01|-1.837718e-06
- 171|1.607174e-01|-3.337578e-07
- 172|1.607173e-01|-5.298564e-07
- 173|1.607173e-01|-6.864278e-08
- 174|1.607173e-01|-2.008419e-07
- 175|1.607171e-01|-1.375630e-06
- 176|1.607168e-01|-1.911257e-06
- 177|1.607167e-01|-2.709815e-07
- 178|1.607167e-01|-1.390953e-07
- 179|1.607165e-01|-1.199675e-06
- It. |Loss |Delta loss
- --------------------------------
- 180|1.607165e-01|-1.457259e-07
- 181|1.607163e-01|-1.049154e-06
- 182|1.607163e-01|-2.753577e-09
- 183|1.607163e-01|-6.972814e-09
- 184|1.607161e-01|-1.552100e-06
- 185|1.607159e-01|-1.068596e-06
- 186|1.607157e-01|-1.247724e-06
- 187|1.607155e-01|-1.158164e-06
- 188|1.607155e-01|-2.616199e-07
- 189|1.607154e-01|-3.595874e-07
- 190|1.607154e-01|-5.334527e-08
- 191|1.607153e-01|-3.452744e-07
- 192|1.607153e-01|-1.239593e-07
- 193|1.607152e-01|-8.184984e-07
- 194|1.607150e-01|-1.316308e-06
- 195|1.607150e-01|-7.100882e-09
- 196|1.607148e-01|-1.393958e-06
- 197|1.607146e-01|-1.242735e-06
- 198|1.607144e-01|-1.123993e-06
- 199|1.607143e-01|-3.512071e-07
- It. |Loss |Delta loss
- --------------------------------
- 200|1.607143e-01|-2.151971e-10
- It. |Loss |Delta loss
- --------------------------------
- 0|1.693084e-01|0.000000e+00
- 1|1.610121e-01|-5.152589e-02
- 2|1.609378e-01|-4.622297e-04
- 3|1.609284e-01|-5.830043e-05
- 4|1.609284e-01|-1.111580e-12
-
+ 160|1.607202e-01|-3.670914e-08
+ 161|1.607197e-01|-3.153709e-06
+ 162|1.607197e-01|-2.419565e-09
+ 163|1.607194e-01|-2.136882e-06
+ 164|1.607194e-01|-1.173754e-09
+ 165|1.607192e-01|-8.169238e-07
+ 166|1.607191e-01|-9.218755e-07
+ 167|1.607189e-01|-9.459255e-07
+ 168|1.607187e-01|-1.294835e-06
+ 169|1.607186e-01|-5.797668e-07
+ 170|1.607186e-01|-4.706272e-08
+ 171|1.607183e-01|-1.753383e-06
+ 172|1.607183e-01|-1.681573e-07
+ 173|1.607183e-01|-2.563971e-10
+Solve EMD with Frobenius norm + entropic regularization
+-------------------------------------------------------
-|
.. code-block:: python
- import numpy as np
- import matplotlib.pylab as pl
- import ot
-
-
-
- #%% parameters
-
- n=100 # nb bins
-
- # bin positions
- x=np.arange(n,dtype=np.float64)
-
- # Gaussian distributions
- a=ot.datasets.get_1D_gauss(n,m=20,s=5) # m= mean, s= std
- b=ot.datasets.get_1D_gauss(n,m=60,s=10)
-
- # loss matrix
- M=ot.dist(x.reshape((n,1)),x.reshape((n,1)))
- M/=M.max()
-
- #%% EMD
+ #%% Example with Frobenius norm + entropic regularization with gcg
- G0=ot.emd(a,b,M)
- pl.figure(3)
- ot.plot.plot1D_mat(a,b,G0,'OT matrix G0')
+ def f(G):
+ return 0.5 * np.sum(G**2)
- #%% Example with Frobenius norm regularization
- def f(G): return 0.5*np.sum(G**2)
- def df(G): return G
+ def df(G):
+ return G
- reg=1e-1
- Gl2=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
+ reg1 = 1e-3
+ reg2 = 1e-1
- pl.figure(3)
- ot.plot.plot1D_mat(a,b,Gl2,'OT matrix Frob. reg')
+ Gel2 = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True)
- #%% Example with entropic regularization
+ pl.figure(5, figsize=(5, 5))
+ ot.plot.plot1D_mat(a, b, Gel2, 'OT entropic + matrix Frob. reg')
+ pl.show()
- def f(G): return np.sum(G*np.log(G))
- def df(G): return np.log(G)+1
- reg=1e-3
- Ge=ot.optim.cg(a,b,M,reg,f,df,verbose=True)
+.. image:: /auto_examples/images/sphx_glr_plot_optim_OTreg_008.png
+ :align: center
- pl.figure(4)
- ot.plot.plot1D_mat(a,b,Ge,'OT matrix Entrop. reg')
- #%% Example with Frobenius norm + entropic regularization with gcg
+.. rst-class:: sphx-glr-script-out
- def f(G): return 0.5*np.sum(G**2)
- def df(G): return G
+ Out::
- reg1=1e-3
- reg2=1e-1
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|1.693084e-01|0.000000e+00
+ 1|1.610121e-01|-5.152589e-02
+ 2|1.609378e-01|-4.622297e-04
+ 3|1.609284e-01|-5.830043e-05
+ 4|1.609284e-01|-1.111407e-12
- Gel2=ot.optim.gcg(a,b,M,reg1,reg2,f,df,verbose=True)
- pl.figure(5)
- ot.plot.plot1D_mat(a,b,Gel2,'OT entropic + matrix Frob. reg')
- pl.show()
-**Total running time of the script:** ( 0 minutes 2.319 seconds)
+**Total running time of the script:** ( 0 minutes 2.800 seconds)
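The regularizers combined by `ot.optim.gcg` in the script above are simple closed-form terms; as a sanity check independent of POT, the Frobenius term `f` and its gradient `df` can be evaluated directly in plain NumPy (a minimal sketch; the uniform 2x2 coupling is a made-up toy, not data from the example):

```python
import numpy as np


def f(G):
    # Frobenius-norm regularizer used in the example: 0.5 * ||G||_F^2
    return 0.5 * np.sum(G ** 2)


def df(G):
    # its gradient is simply the coupling itself
    return G


# toy coupling: uniform 2x2 plan between two uniform histograms
G = np.full((2, 2), 0.25)
print(f(G))                   # 0.5 * 4 * 0.25**2 = 0.125
print(np.allclose(df(G), G))  # True
```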
diff --git a/docs/source/auto_examples/plot_otda_classes.ipynb b/docs/source/auto_examples/plot_otda_classes.ipynb
new file mode 100644
index 0000000..6754fa5
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_classes.ipynb
@@ -0,0 +1,126 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "\n# OT for domain adaptation\n\n\nThis example introduces a domain adaptation problem in a 2D setting and the 4 OTDA\napproaches currently supported in POT.\n\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "n_source_samples = 150\nn_target_samples = 150\n\nXs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)\nXt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# EMD Transport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport with Group lasso regularization\not_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)\not_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)\n\n# Sinkhorn Transport with Group lasso regularization l1l2\not_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,\n verbose=True)\not_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)\n\n# transport source samples onto target samples\ntransp_Xs_emd = ot_emd.transform(Xs=Xs)\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)\ntransp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)\ntransp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 1 : plots source and target samples\n---------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(1, figsize=(10, 5))\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Source samples')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Target samples')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 2 : plot optimal couplings and transported samples\n------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}\n\npl.figure(2, figsize=(15, 8))\npl.subplot(2, 4, 1)\npl.imshow(ot_emd.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nEMDTransport')\n\npl.subplot(2, 4, 2)\npl.imshow(ot_sinkhorn.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornTransport')\n\npl.subplot(2, 4, 3)\npl.imshow(ot_lpl1.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornLpl1Transport')\n\npl.subplot(2, 4, 4)\npl.imshow(ot_l1l2.coupling_, **param_img)\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornL1l2Transport')\n\npl.subplot(2, 4, 5)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nEmdTransport')\npl.legend(loc=\"lower left\")\n\npl.subplot(2, 4, 6)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nSinkhornTransport')\n\npl.subplot(2, 4, 7)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nSinkhornLpl1Transport')\n\npl.subplot(2, 4, 8)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.3)\npl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.xticks([])\npl.yticks([])\npl.title('Transported samples\\nSinkhornL1l2Transport')\npl.tight_layout()\n\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_classes.py b/docs/source/auto_examples/plot_otda_classes.py
new file mode 100644
index 0000000..b14c11a
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_classes.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+"""
+========================
+OT for domain adaptation
+========================
+
+This example introduces a domain adaptation problem in a 2D setting and the
+4 OTDA approaches currently supported in POT.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 150
+n_target_samples = 150
+
+Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)
+Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization
+ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization l1l2
+ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
+ verbose=True)
+ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
+
+
+##############################################################################
+# Fig 1 : plots source and target samples
+# ---------------------------------------
+
+pl.figure(1, figsize=(10, 5))
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plot optimal couplings and transported samples
+# ------------------------------------------------------
+
+param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}
+
+pl.figure(2, figsize=(15, 8))
+pl.subplot(2, 4, 1)
+pl.imshow(ot_emd.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.subplot(2, 4, 2)
+pl.imshow(ot_sinkhorn.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 4, 3)
+pl.imshow(ot_lpl1.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+pl.subplot(2, 4, 4)
+pl.imshow(ot_l1l2.coupling_, **param_img)
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornL1l2Transport')
+
+pl.subplot(2, 4, 5)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc="lower left")
+
+pl.subplot(2, 4, 6)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornTransport')
+
+pl.subplot(2, 4, 7)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornLpl1Transport')
+
+pl.subplot(2, 4, 8)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.xticks([])
+pl.yticks([])
+pl.title('Transported samples\nSinkhornL1l2Transport')
+pl.tight_layout()
+
+pl.show()
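The script above stops at plotting. A common follow-up, not part of this example, is to score the transported source samples with a simple classifier trained in the target domain; this is a minimal 1-NN sketch in plain NumPy (the toy clusters below stand in for arrays shaped like `transp_Xs_emd`/`ys`/`Xt`/`yt`):

```python
import numpy as np


def knn1_accuracy(X_train, y_train, X_test, y_test):
    # label each test point with the label of its nearest training point
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(-1)
    pred = y_train[d2.argmin(axis=1)]
    return (pred == y_test).mean()


# toy check on two well-separated Gaussian clusters
rng = np.random.RandomState(0)
X = np.vstack([rng.randn(20, 2), rng.randn(20, 2) + 10])
y = np.array([0] * 20 + [1] * 20)
print(knn1_accuracy(X, y, X, y))  # 1.0 (each point is its own nearest neighbor)
```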
diff --git a/docs/source/auto_examples/plot_otda_classes.rst b/docs/source/auto_examples/plot_otda_classes.rst
new file mode 100644
index 0000000..a5ab285
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_classes.rst
@@ -0,0 +1,258 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_classes.py:
+
+
+========================
+OT for domain adaptation
+========================
+
+This example introduces a domain adaptation problem in a 2D setting and the
+4 OTDA approaches currently supported in POT.
+
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_source_samples = 150
+ n_target_samples = 150
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', n_source_samples)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', n_target_samples)
+
+
+
+
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMD Transport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport with Group lasso regularization
+ ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ # Sinkhorn Transport with Group lasso regularization l1l2
+ ot_l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1e-1, reg_cl=2e0, max_iter=20,
+ verbose=True)
+ ot_l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ # transport source samples onto target samples
+ transp_Xs_emd = ot_emd.transform(Xs=Xs)
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+ transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+ transp_Xs_l1l2 = ot_l1l2.transform(Xs=Xs)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|1.003747e+01|0.000000e+00
+ 1|1.953263e+00|-4.138821e+00
+ 2|1.744456e+00|-1.196969e-01
+ 3|1.689268e+00|-3.267022e-02
+ 4|1.666355e+00|-1.374998e-02
+ 5|1.656125e+00|-6.177356e-03
+ 6|1.651753e+00|-2.646960e-03
+ 7|1.647261e+00|-2.726957e-03
+ 8|1.642274e+00|-3.036672e-03
+ 9|1.639926e+00|-1.431818e-03
+ 10|1.638750e+00|-7.173837e-04
+ 11|1.637558e+00|-7.281753e-04
+ 12|1.636248e+00|-8.002067e-04
+ 13|1.634555e+00|-1.036074e-03
+ 14|1.633547e+00|-6.166646e-04
+ 15|1.633531e+00|-1.022614e-05
+ 16|1.632957e+00|-3.510986e-04
+ 17|1.632853e+00|-6.380944e-05
+ 18|1.632704e+00|-9.122988e-05
+ 19|1.632237e+00|-2.861276e-04
+ It. |Loss |Delta loss
+ --------------------------------
+ 20|1.632174e+00|-3.896483e-05
+
+
+Fig 1 : plots source and target samples
+---------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(10, 5))
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Source samples')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Target samples')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_classes_001.png
+ :align: center
+
+
+
+
+Fig 2 : plot optimal couplings and transported samples
+------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ param_img = {'interpolation': 'nearest', 'cmap': 'spectral'}
+
+ pl.figure(2, figsize=(15, 8))
+ pl.subplot(2, 4, 1)
+ pl.imshow(ot_emd.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nEMDTransport')
+
+ pl.subplot(2, 4, 2)
+ pl.imshow(ot_sinkhorn.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornTransport')
+
+ pl.subplot(2, 4, 3)
+ pl.imshow(ot_lpl1.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+ pl.subplot(2, 4, 4)
+ pl.imshow(ot_l1l2.coupling_, **param_img)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornL1l2Transport')
+
+ pl.subplot(2, 4, 5)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nEmdTransport')
+ pl.legend(loc="lower left")
+
+ pl.subplot(2, 4, 6)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nSinkhornTransport')
+
+ pl.subplot(2, 4, 7)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nSinkhornLpl1Transport')
+
+ pl.subplot(2, 4, 8)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.3)
+ pl.scatter(transp_Xs_l1l2[:, 0], transp_Xs_l1l2[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Transported samples\nSinkhornL1l2Transport')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_classes_003.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 2.308 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_classes.py <plot_otda_classes.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_classes.ipynb <plot_otda_classes.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_color_images.ipynb b/docs/source/auto_examples/plot_otda_color_images.ipynb
new file mode 100644
index 0000000..2daf406
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_color_images.ipynb
@@ -0,0 +1,144 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "\n# OT for image color adaptation\n\n\nThis example presents a way of transferring colors between two images\nwith Optimal Transport, as introduced in [6].\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).\nRegularized discrete optimal transport.\nSIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+        "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nfrom scipy import ndimage\nimport matplotlib.pylab as pl\nimport ot\n\n\nr = np.random.RandomState(42)\n\n\ndef im2mat(I):\n    \"\"\"Converts an image to a matrix (one pixel per line)\"\"\"\n    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef mat2im(X, shape):\n    \"\"\"Converts a matrix back to an image\"\"\"\n    return X.reshape(shape)\n\n\ndef minmax(I):\n    return np.clip(I, 0, 1)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Loading images\nI1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256\nI2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)\n\n# training samples\nnb = 1000\nidx1 = r.randint(X1.shape[0], size=(nb,))\nidx2 = r.randint(X2.shape[0], size=(nb,))\n\nXs = X1[idx1, :]\nXt = X2[idx2, :]"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot original image\n-------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(1, figsize=(6.4, 3))\n\npl.subplot(1, 2, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Scatter plot of colors\n----------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(2, figsize=(6.4, 3))\n\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+        "# EMDTransport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# SinkhornTransport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# prediction between images (using out of sample prediction as in [6])\ntransp_Xs_emd = ot_emd.transform(Xs=X1)\ntransp_Xt_emd = ot_emd.inverse_transform(Xt=X2)\n\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)\ntransp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)\n\nI1t = minmax(mat2im(transp_Xs_emd, I1.shape))\nI2t = minmax(mat2im(transp_Xt_emd, I2.shape))\n\nI1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))\nI2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot new images\n---------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(3, figsize=(8, 4))\n\npl.subplot(2, 3, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(2, 3, 2)\npl.imshow(I1t)\npl.axis('off')\npl.title('Image 1 Adapt')\n\npl.subplot(2, 3, 3)\npl.imshow(I1te)\npl.axis('off')\npl.title('Image 1 Adapt (reg)')\n\npl.subplot(2, 3, 4)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')\n\npl.subplot(2, 3, 5)\npl.imshow(I2t)\npl.axis('off')\npl.title('Image 2 Adapt')\n\npl.subplot(2, 3, 6)\npl.imshow(I2te)\npl.axis('off')\npl.title('Image 2 Adapt (reg)')\npl.tight_layout()\n\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
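The `im2mat`/`mat2im` helpers used in this notebook are a plain reshape pair; a quick round-trip check in pure NumPy (toy array, not the example images):

```python
import numpy as np


def im2mat(I):
    # flatten an (H, W, C) image to an (H*W, C) matrix, one pixel per row
    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))


def mat2im(X, shape):
    # inverse operation: restore the original (H, W, C) shape
    return X.reshape(shape)


I = np.arange(24, dtype=np.float64).reshape((2, 4, 3))
X = im2mat(I)
print(X.shape)                                 # (8, 3)
print(np.array_equal(mat2im(X, I.shape), I))   # True
```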
diff --git a/docs/source/auto_examples/plot_otda_color_images.py b/docs/source/auto_examples/plot_otda_color_images.py
new file mode 100644
index 0000000..e77aec0
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_color_images.py
@@ -0,0 +1,165 @@
+# -*- coding: utf-8 -*-
+"""
+=============================
+OT for image color adaptation
+=============================
+
+This example presents a way of transferring colors between two images
+with Optimal Transport, as introduced in [6].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport.
+SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+from scipy import ndimage
+import matplotlib.pylab as pl
+import ot
+
+
+r = np.random.RandomState(42)
+
+
+def im2mat(I):
+    """Converts an image to a matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+    """Converts a matrix back to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
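
These two helpers are exact inverses of each other; a minimal standalone sanity check (toy data, independent of the example images):

```python
import numpy as np

# Mirrors the example's helpers: flatten an image to an
# (n_pixels, n_channels) matrix and reshape it back.
def im2mat(I):
    return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))

def mat2im(X, shape):
    return X.reshape(shape)

I = np.random.rand(4, 5, 3)                    # toy 4x5 RGB "image"
X = im2mat(I)
assert X.shape == (20, 3)                      # one pixel per row
assert np.array_equal(mat2im(X, I.shape), I)   # lossless round trip
```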
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# Loading images
+I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+# training samples
+nb = 1000
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
+
+Xs = X1[idx1, :]
+Xt = X2[idx2, :]
+
+
+##############################################################################
+# Plot original image
+# -------------------
+
+pl.figure(1, figsize=(6.4, 3))
+
+pl.subplot(1, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+
+
+##############################################################################
+# Scatter plot of colors
+# ----------------------
+
+pl.figure(2, figsize=(6.4, 3))
+
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMDTransport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# SinkhornTransport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# prediction between images (using out-of-sample prediction as in [6])
+transp_Xs_emd = ot_emd.transform(Xs=X1)
+transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
+
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+transp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)
+
+I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
+I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
+
+I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
+
+
+##############################################################################
+# Plot new images
+# ---------------
+
+pl.figure(3, figsize=(8, 4))
+
+pl.subplot(2, 3, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(2, 3, 2)
+pl.imshow(I1t)
+pl.axis('off')
+pl.title('Image 1 Adapt')
+
+pl.subplot(2, 3, 3)
+pl.imshow(I1te)
+pl.axis('off')
+pl.title('Image 1 Adapt (reg)')
+
+pl.subplot(2, 3, 4)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+
+pl.subplot(2, 3, 5)
+pl.imshow(I2t)
+pl.axis('off')
+pl.title('Image 2 Adapt')
+
+pl.subplot(2, 3, 6)
+pl.imshow(I2te)
+pl.axis('off')
+pl.title('Image 2 Adapt (reg)')
+pl.tight_layout()
+
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_color_images.rst b/docs/source/auto_examples/plot_otda_color_images.rst
new file mode 100644
index 0000000..9c31ba7
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_color_images.rst
@@ -0,0 +1,257 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_color_images.py:
+
+
+=============================
+OT for image color adaptation
+=============================
+
+This example presents a way of transferring colors between two images
+with Optimal Transport, as introduced in [6].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).
+Regularized discrete optimal transport.
+SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import numpy as np
+ from scipy import ndimage
+ import matplotlib.pylab as pl
+ import ot
+
+
+ r = np.random.RandomState(42)
+
+
+ def im2mat(I):
+        """Converts an image to a matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ def mat2im(X, shape):
+        """Converts a matrix back to an image"""
+ return X.reshape(shape)
+
+
+ def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ # Loading images
+ I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+ I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+ X1 = im2mat(I1)
+ X2 = im2mat(I2)
+
+ # training samples
+ nb = 1000
+ idx1 = r.randint(X1.shape[0], size=(nb,))
+ idx2 = r.randint(X2.shape[0], size=(nb,))
+
+ Xs = X1[idx1, :]
+ Xt = X2[idx2, :]
+
+
+
+
+
+
+
+
+Plot original image
+-------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(6.4, 3))
+
+ pl.subplot(1, 2, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Image 2')
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_color_images_001.png
+ :align: center
+
+
+
+
+Scatter plot of colors
+----------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(6.4, 3))
+
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 2')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_color_images_003.png
+ :align: center
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMDTransport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ # SinkhornTransport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+    # prediction between images (using out-of-sample prediction as in [6])
+    transp_Xs_emd = ot_emd.transform(Xs=X1)
+    transp_Xt_emd = ot_emd.inverse_transform(Xt=X2)
+
+    transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+    transp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)
+
+ I1t = minmax(mat2im(transp_Xs_emd, I1.shape))
+ I2t = minmax(mat2im(transp_Xt_emd, I2.shape))
+
+ I1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+ I2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))
+
+
+
+
+
+
+
+
+Plot new images
+---------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(3, figsize=(8, 4))
+
+ pl.subplot(2, 3, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Image 1')
+
+ pl.subplot(2, 3, 2)
+ pl.imshow(I1t)
+ pl.axis('off')
+ pl.title('Image 1 Adapt')
+
+ pl.subplot(2, 3, 3)
+ pl.imshow(I1te)
+ pl.axis('off')
+ pl.title('Image 1 Adapt (reg)')
+
+ pl.subplot(2, 3, 4)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Image 2')
+
+ pl.subplot(2, 3, 5)
+ pl.imshow(I2t)
+ pl.axis('off')
+ pl.title('Image 2 Adapt')
+
+ pl.subplot(2, 3, 6)
+ pl.imshow(I2te)
+ pl.axis('off')
+ pl.title('Image 2 Adapt (reg)')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_color_images_005.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 3 minutes 16.469 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_color_images.py <plot_otda_color_images.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_color_images.ipynb <plot_otda_color_images.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_d2.ipynb b/docs/source/auto_examples/plot_otda_d2.ipynb
new file mode 100644
index 0000000..7bfcc9a
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_d2.ipynb
@@ -0,0 +1,144 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "\n# OT for domain adaptation on empirical distributions\n\n\nThis example introduces a domain adaptation problem in a 2D setting. It\nmakes explicit the problem of domain adaptation and introduces some optimal\ntransport approaches to solve it.\n\nQuantities such as optimal couplings, the largest coupling coefficients and\ntransported samples are represented in order to give a visual understanding\nof what the transport methods are doing.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "n_samples_source = 150\nn_samples_target = 150\n\nXs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)\nXt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)\n\n# Cost matrix\nM = ot.dist(Xs, Xt, metric='sqeuclidean')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# EMD Transport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# Sinkhorn Transport with Group lasso regularization\not_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)\not_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)\n\n# transport source samples onto target samples\ntransp_Xs_emd = ot_emd.transform(Xs=Xs)\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)\ntransp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 1 : plots source and target samples + matrix of pairwise distance\n---------------------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(1, figsize=(10, 10))\npl.subplot(2, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Source samples')\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Target samples')\n\npl.subplot(2, 2, 3)\npl.imshow(M, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Matrix of pairwise distances')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 2 : plots optimal couplings for the different methods\n---------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(2, figsize=(10, 6))\n\npl.subplot(2, 3, 1)\npl.imshow(ot_emd.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nEMDTransport')\n\npl.subplot(2, 3, 2)\npl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornTransport')\n\npl.subplot(2, 3, 3)\npl.imshow(ot_lpl1.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSinkhornLpl1Transport')\n\npl.subplot(2, 3, 4)\not.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.title('Main coupling coefficients\\nEMDTransport')\n\npl.subplot(2, 3, 5)\not.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.title('Main coupling coefficients\\nSinkhornTransport')\n\npl.subplot(2, 3, 6)\not.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.title('Main coupling coefficients\\nSinkhornLpl1Transport')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 3 : plot transported samples\n--------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# display transported samples\npl.figure(4, figsize=(10, 4))\npl.subplot(1, 3, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nEmdTransport')\npl.legend(loc=0)\npl.xticks([])\npl.yticks([])\n\npl.subplot(1, 3, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nSinkhornTransport')\npl.xticks([])\npl.yticks([])\n\npl.subplot(1, 3, 3)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nSinkhornLpl1Transport')\npl.xticks([])\npl.yticks([])\n\npl.tight_layout()\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_d2.py b/docs/source/auto_examples/plot_otda_d2.py
new file mode 100644
index 0000000..e53d7d6
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_d2.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+"""
+===================================================
+OT for domain adaptation on empirical distributions
+===================================================
+
+This example introduces a domain adaptation problem in a 2D setting. It
+makes explicit the problem of domain adaptation and introduces some optimal
+transport approaches to solve it.
+
+Quantities such as optimal couplings, the largest coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# generate data
+# -------------
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
+
+# Cost matrix
+M = ot.dist(Xs, Xt, metric='sqeuclidean')
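
With `metric='sqeuclidean'`, the cost matrix entry `M[i, j]` is the squared Euclidean distance between source sample `i` and target sample `j`. A minimal NumPy-only sketch of the same computation (toy data, independent of `ot.dist`):

```python
import numpy as np

rng = np.random.RandomState(0)
Xs = rng.randn(5, 2)   # 5 source samples in 2D
Xt = rng.randn(4, 2)   # 4 target samples in 2D

# M[i, j] = ||Xs[i] - Xt[j]||^2 via the expansion
# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
M = (Xs ** 2).sum(1)[:, None] + (Xt ** 2).sum(1)[None, :] - 2 * Xs.dot(Xt.T)

# Check against the naive double loop
M_ref = np.array([[np.sum((a - b) ** 2) for b in Xt] for a in Xs])
assert M.shape == (5, 4)
assert np.allclose(M, M_ref)
```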
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# EMD Transport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+# Sinkhorn Transport with Group lasso regularization
+ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+# transport source samples onto target samples
+transp_Xs_emd = ot_emd.transform(Xs=Xs)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
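
The fitted `coupling_` matrices visualized below are transport plans whose row and column sums match the source and target sample weights. As an illustration of that constraint (a standalone sketch using `scipy.optimize.linprog`, not POT's actual solver), here is a toy EMD solved as a linear program:

```python
import numpy as np
from scipy.optimize import linprog

# Toy EMD: 2 source and 2 target points, uniform weights,
# with a cost matrix that favors the identity matching
M = np.array([[0.0, 1.0],
              [1.0, 0.0]])
a = np.array([0.5, 0.5])   # source weights
b = np.array([0.5, 0.5])   # target weights

# Flatten the coupling G (2x2) into (g00, g01, g10, g11);
# the equality constraints encode the row and column marginals
A_eq = np.array([[1, 1, 0, 0],    # G[0, :].sum() == a[0]
                 [0, 0, 1, 1],    # G[1, :].sum() == a[1]
                 [1, 0, 1, 0],    # G[:, 0].sum() == b[0]
                 [0, 1, 0, 1]])   # G[:, 1].sum() == b[1]
res = linprog(M.ravel(), A_eq=A_eq, b_eq=np.concatenate([a, b]),
              bounds=(0, None))
G = res.x.reshape(2, 2)

assert np.allclose(G.sum(axis=1), a)   # marginals are preserved
assert np.allclose(G.sum(axis=0), b)
assert np.isclose(res.fun, 0.0)        # identity matching costs nothing
```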
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+# ---------------------------------------------------------------------
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(M, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Matrix of pairwise distances')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+# ---------------------------------------------------------
+pl.figure(2, figsize=(10, 6))
+
+pl.subplot(2, 3, 1)
+pl.imshow(ot_emd.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nEMDTransport')
+
+pl.subplot(2, 3, 2)
+pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(ot_lpl1.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+pl.subplot(2, 3, 4)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nEMDTransport')
+
+pl.subplot(2, 3, 5)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nSinkhornTransport')
+
+pl.subplot(2, 3, 6)
+ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+# --------------------------------
+
+# display transported samples
+pl.figure(4, figsize=(10, 4))
+pl.subplot(1, 3, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nEmdTransport')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 3, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornTransport')
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 3, 3)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSinkhornLpl1Transport')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/docs/source/auto_examples/plot_otda_d2.rst b/docs/source/auto_examples/plot_otda_d2.rst
new file mode 100644
index 0000000..1bbe6d9
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_d2.rst
@@ -0,0 +1,264 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_d2.py:
+
+
+===================================================
+OT for domain adaptation on empirical distributions
+===================================================
+
+This example introduces a domain adaptation problem in a 2D setting. It
+makes explicit the problem of domain adaptation and introduces some optimal
+transport approaches to solve it.
+
+Quantities such as optimal couplings, the largest coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_samples_source = 150
+ n_samples_target = 150
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
+
+ # Cost matrix
+ M = ot.dist(Xs, Xt, metric='sqeuclidean')
+
+
+
+
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMD Transport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+
+ # Sinkhorn Transport with Group lasso regularization
+ ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0)
+ ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+ # transport source samples onto target samples
+ transp_Xs_emd = ot_emd.transform(Xs=Xs)
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs)
+ transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs)
+
+
+
+
+
+
+
+
+Fig 1 : plots source and target samples + matrix of pairwise distance
+---------------------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(10, 10))
+ pl.subplot(2, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Source samples')
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Target samples')
+
+ pl.subplot(2, 2, 3)
+ pl.imshow(M, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Matrix of pairwise distances')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_d2_001.png
+ :align: center
+
+
+
+
+Fig 2 : plots optimal couplings for the different methods
+---------------------------------------------------------
+
+
+
+.. code-block:: python
+
+ pl.figure(2, figsize=(10, 6))
+
+ pl.subplot(2, 3, 1)
+ pl.imshow(ot_emd.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nEMDTransport')
+
+ pl.subplot(2, 3, 2)
+ pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornTransport')
+
+ pl.subplot(2, 3, 3)
+ pl.imshow(ot_lpl1.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSinkhornLpl1Transport')
+
+ pl.subplot(2, 3, 4)
+ ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1])
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Main coupling coefficients\nEMDTransport')
+
+ pl.subplot(2, 3, 5)
+ ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1])
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Main coupling coefficients\nSinkhornTransport')
+
+ pl.subplot(2, 3, 6)
+ ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1])
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Main coupling coefficients\nSinkhornLpl1Transport')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_d2_003.png
+ :align: center
+
+
+
+
+Fig 3 : plot transported samples
+--------------------------------
+
+
+
+.. code-block:: python
+
+
+ # display transported samples
+ pl.figure(4, figsize=(10, 4))
+ pl.subplot(1, 3, 1)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nEmdTransport')
+ pl.legend(loc=0)
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.subplot(1, 3, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nSinkhornTransport')
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.subplot(1, 3, 3)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+ pl.title('Transported samples\nSinkhornLpl1Transport')
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.tight_layout()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_d2_006.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 47.000 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_d2.py <plot_otda_d2.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_d2.ipynb <plot_otda_d2.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_mapping.ipynb b/docs/source/auto_examples/plot_otda_mapping.ipynb
new file mode 100644
index 0000000..0374146
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping.ipynb
@@ -0,0 +1,126 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+        "\n# OT mapping estimation for domain adaptation\n\n\nThis example shows how to use MappingTransport to estimate both the\ntransport coupling and an approximation of the transport map, using either\na linear or a kernelized mapping, as introduced in [8].\n\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,\n    \"Mapping estimation for discrete optimal transport\",\n    Neural Information Processing Systems (NIPS), 2016.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+        "n_source_samples = 100\nn_target_samples = 100\ntheta = 2 * np.pi / 20\nnoise_level = 0.1\n\nXs, ys = ot.datasets.get_data_classif(\n    'gaussrot', n_source_samples, nz=noise_level)\nXs_new, _ = ot.datasets.get_data_classif(\n    'gaussrot', n_source_samples, nz=noise_level)\nXt, yt = ot.datasets.get_data_classif(\n    'gaussrot', n_target_samples, theta=theta, nz=noise_level)\n\n# one of the target modes changes its variance (no linear mapping)\nXt[yt == 2] *= 3\nXt = Xt + 4"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot data\n---------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(1, (10, 5))\npl.clf()\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.legend(loc=0)\npl.title('Source and target distributions')"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# MappingTransport with linear kernel\not_mapping_linear = ot.da.MappingTransport(\n kernel=\"linear\", mu=1e0, eta=1e-8, bias=True,\n max_iter=20, verbose=True)\n\not_mapping_linear.fit(Xs=Xs, Xt=Xt)\n\n# for original source samples, transform applies barycentric mapping\ntransp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)\n\n# for out of source samples, transform applies the linear mapping\ntransp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)\n\n\n# MappingTransport with gaussian kernel\not_mapping_gaussian = ot.da.MappingTransport(\n kernel=\"gaussian\", eta=1e-5, mu=1e-1, bias=True, sigma=1,\n max_iter=10, verbose=True)\not_mapping_gaussian.fit(Xs=Xs, Xt=Xt)\n\n# for original source samples, transform applies barycentric mapping\ntransp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)\n\n# for out of source samples, transform applies the gaussian mapping\ntransp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot transported samples\n------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(2)\npl.clf()\npl.subplot(2, 2, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',\n label='Mapped source samples')\npl.title(\"Bary. mapping (linear)\")\npl.legend(loc=0)\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],\n c=ys, marker='+', label='Learned mapping')\npl.title(\"Estim. mapping (linear)\")\n\npl.subplot(2, 2, 3)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,\n marker='+', label='barycentric mapping')\npl.title(\"Bary. mapping (kernel)\")\n\npl.subplot(2, 2, 4)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=.2)\npl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,\n marker='+', label='Learned mapping')\npl.title(\"Estim. mapping (kernel)\")\npl.tight_layout()\n\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_mapping.py b/docs/source/auto_examples/plot_otda_mapping.py
new file mode 100644
index 0000000..167c3a1
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping.py
@@ -0,0 +1,125 @@
+# -*- coding: utf-8 -*-
+"""
+===========================================
+OT mapping estimation for domain adaptation
+===========================================
+
+This example shows how to use the MappingTransport class to jointly estimate
+an optimal coupling and an approximation of the transport map, with either a
+linear or a kernelized mapping, as introduced in [8].
+
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_source_samples = 100
+n_target_samples = 100
+theta = 2 * np.pi / 20
+noise_level = 0.1
+
+Xs, ys = ot.datasets.get_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xs_new, _ = ot.datasets.get_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+Xt, yt = ot.datasets.get_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+# one of the target modes changes its variance (so no linear mapping can match it)
+Xt[yt == 2] *= 3
+Xt = Xt + 4
+
+##############################################################################
+# Plot data
+# ---------
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+
+##############################################################################
+# Instantiate the different transport algorithms and fit them
+# -----------------------------------------------------------
+
+# MappingTransport with linear kernel
+ot_mapping_linear = ot.da.MappingTransport(
+ kernel="linear", mu=1e0, eta=1e-8, bias=True,
+ max_iter=20, verbose=True)
+
+ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+# for original source samples, transform applies barycentric mapping
+transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)
+
+# for out-of-sample source data, transform applies the learned linear mapping
+transp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)
+
+
+# MappingTransport with gaussian kernel
+ot_mapping_gaussian = ot.da.MappingTransport(
+ kernel="gaussian", eta=1e-5, mu=1e-1, bias=True, sigma=1,
+ max_iter=10, verbose=True)
+ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+# for original source samples, transform applies barycentric mapping
+transp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)
+
+# for out-of-sample source data, transform applies the learned Gaussian kernel mapping
+transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)
+
+
+##############################################################################
+# Plot transported samples
+# ------------------------
+
+pl.figure(2)
+pl.clf()
+pl.subplot(2, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',
+ label='Mapped source samples')
+pl.title("Bary. mapping (linear)")
+pl.legend(loc=0)
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],
+ c=ys, marker='+', label='Learned mapping')
+pl.title("Estim. mapping (linear)")
+
+pl.subplot(2, 2, 3)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,
+ marker='+', label='barycentric mapping')
+pl.title("Bary. mapping (kernel)")
+
+pl.subplot(2, 2, 4)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+pl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,
+ marker='+', label='Learned mapping')
+pl.title("Estim. mapping (kernel)")
+pl.tight_layout()
+
+pl.show()
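The example above distinguishes the barycentric mapping, applied to the fitted source samples, from the learned map applied to new samples. As a minimal, self-contained sketch of what a barycentric mapping does (a toy coupling with illustrative values, not the POT internals):

```python
import numpy as np

# toy coupling between 3 source and 2 target points
G = np.array([[0.2, 0.1],
              [0.0, 0.3],
              [0.25, 0.15]])
Xt = np.array([[0.0, 0.0],
               [2.0, 2.0]])

# barycentric mapping: each source sample is sent to the
# coupling-weighted average of the target samples
transp_Xs = (G / G.sum(axis=1, keepdims=True)) @ Xt

# the second source point is coupled only to the second target
# point, so it lands exactly on it
print(transp_Xs[1])  # -> [2. 2.]
```

This only moves samples that were seen during fitting; the point of MappingTransport is to additionally learn a parametric map usable on out-of-sample data.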
diff --git a/docs/source/auto_examples/plot_otda_mapping.rst b/docs/source/auto_examples/plot_otda_mapping.rst
new file mode 100644
index 0000000..1e3a709
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping.rst
@@ -0,0 +1,230 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_mapping.py:
+
+
+===========================================
+OT mapping estimation for domain adaptation
+===========================================
+
+This example shows how to use the MappingTransport class to jointly estimate
+an optimal coupling and an approximation of the transport map, with either a
+linear or a kernelized mapping, as introduced in [8].
+
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import numpy as np
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_source_samples = 100
+ n_target_samples = 100
+ theta = 2 * np.pi / 20
+ noise_level = 0.1
+
+ Xs, ys = ot.datasets.get_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+ Xs_new, _ = ot.datasets.get_data_classif(
+ 'gaussrot', n_source_samples, nz=noise_level)
+ Xt, yt = ot.datasets.get_data_classif(
+ 'gaussrot', n_target_samples, theta=theta, nz=noise_level)
+
+ # one of the target modes changes its variance (so no linear mapping can match it)
+ Xt[yt == 2] *= 3
+ Xt = Xt + 4
+
+
+
+
+
+
+
+Plot data
+---------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, (10, 5))
+ pl.clf()
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.legend(loc=0)
+ pl.title('Source and target distributions')
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_001.png
+ :align: center
+
+
+
+
+Instantiate the different transport algorithms and fit them
+-----------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # MappingTransport with linear kernel
+ ot_mapping_linear = ot.da.MappingTransport(
+ kernel="linear", mu=1e0, eta=1e-8, bias=True,
+ max_iter=20, verbose=True)
+
+ ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+ # for original source samples, transform applies barycentric mapping
+ transp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)
+
+ # for out-of-sample source data, transform applies the learned linear mapping
+ transp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)
+
+
+ # MappingTransport with gaussian kernel
+ ot_mapping_gaussian = ot.da.MappingTransport(
+ kernel="gaussian", eta=1e-5, mu=1e-1, bias=True, sigma=1,
+ max_iter=10, verbose=True)
+ ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+ # for original source samples, transform applies barycentric mapping
+ transp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)
+
+ # for out-of-sample source data, transform applies the learned Gaussian kernel mapping
+ transp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|4.231423e+03|0.000000e+00
+ 1|4.217955e+03|-3.182835e-03
+ 2|4.217580e+03|-8.885864e-05
+ 3|4.217451e+03|-3.043162e-05
+ 4|4.217368e+03|-1.978325e-05
+ 5|4.217312e+03|-1.338471e-05
+ 6|4.217307e+03|-1.000290e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|4.257004e+02|0.000000e+00
+ 1|4.208978e+02|-1.128168e-02
+ 2|4.205168e+02|-9.052112e-04
+ 3|4.203566e+02|-3.810681e-04
+ 4|4.202570e+02|-2.369884e-04
+ 5|4.201844e+02|-1.726132e-04
+ 6|4.201341e+02|-1.196461e-04
+ 7|4.200941e+02|-9.525441e-05
+ 8|4.200630e+02|-7.405552e-05
+ 9|4.200377e+02|-6.031884e-05
+ 10|4.200168e+02|-4.968324e-05
+
+
+Plot transported samples
+------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2)
+ pl.clf()
+ pl.subplot(2, 2, 1)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',
+ label='Mapped source samples')
+ pl.title("Bary. mapping (linear)")
+ pl.legend(loc=0)
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],
+ c=ys, marker='+', label='Learned mapping')
+ pl.title("Estim. mapping (linear)")
+
+ pl.subplot(2, 2, 3)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,
+ marker='+', label='barycentric mapping')
+ pl.title("Bary. mapping (kernel)")
+
+ pl.subplot(2, 2, 4)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=.2)
+ pl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,
+ marker='+', label='Learned mapping')
+ pl.title("Estim. mapping (kernel)")
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_003.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.970 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_mapping.py <plot_otda_mapping.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_mapping.ipynb <plot_otda_mapping.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb b/docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb
new file mode 100644
index 0000000..56caa8a
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping_colors_images.ipynb
@@ -0,0 +1,144 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "\n# OT for image color adaptation with mapping estimation\n\n\nOT for domain adaptation with image color adaptation [6] with mapping\nestimation [8].\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized\n discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),\n 1853-1882.\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, \"Mapping estimation for\n discrete optimal transport\", Neural Information Processing Systems (NIPS),\n 2016.\n\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nfrom scipy import ndimage\nimport matplotlib.pylab as pl\nimport ot\n\nr = np.random.RandomState(42)\n\n\ndef im2mat(I):\n \"\"\"Converts an image to a matrix (one pixel per row)\"\"\"\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef mat2im(X, shape):\n \"\"\"Converts a matrix back to an image\"\"\"\n return X.reshape(shape)\n\n\ndef minmax(I):\n return np.clip(I, 0, 1)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Loading images\nI1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256\nI2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256\n\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)\n\n# training samples\nnb = 1000\nidx1 = r.randint(X1.shape[0], size=(nb,))\nidx2 = r.randint(X2.shape[0], size=(nb,))\n\nXs = X1[idx1, :]\nXt = X2[idx2, :]"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Domain adaptation for pixel distribution transfer\n-------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# EMDTransport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\ntransp_Xs_emd = ot_emd.transform(Xs=X1)\nImage_emd = minmax(mat2im(transp_Xs_emd, I1.shape))\n\n# SinkhornTransport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)\nImage_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))\n\not_mapping_linear = ot.da.MappingTransport(\n mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)\not_mapping_linear.fit(Xs=Xs, Xt=Xt)\n\nX1tl = ot_mapping_linear.transform(Xs=X1)\nImage_mapping_linear = minmax(mat2im(X1tl, I1.shape))\n\not_mapping_gaussian = ot.da.MappingTransport(\n mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)\not_mapping_gaussian.fit(Xs=Xs, Xt=Xt)\n\nX1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping\nImage_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot original images\n--------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(1, figsize=(6.4, 3))\npl.subplot(1, 2, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot pixel values distribution\n------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(2, figsize=(6.4, 5))\n\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Plot transformed images\n-----------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(2, figsize=(10, 5))\n\npl.subplot(2, 3, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Im. 1')\n\npl.subplot(2, 3, 4)\npl.imshow(I2)\npl.axis('off')\npl.title('Im. 2')\n\npl.subplot(2, 3, 2)\npl.imshow(Image_emd)\npl.axis('off')\npl.title('EmdTransport')\n\npl.subplot(2, 3, 5)\npl.imshow(Image_sinkhorn)\npl.axis('off')\npl.title('SinkhornTransport')\n\npl.subplot(2, 3, 3)\npl.imshow(Image_mapping_linear)\npl.axis('off')\npl.title('MappingTransport (linear)')\n\npl.subplot(2, 3, 6)\npl.imshow(Image_mapping_gaussian)\npl.axis('off')\npl.title('MappingTransport (gaussian)')\npl.tight_layout()\n\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_mapping_colors_images.py b/docs/source/auto_examples/plot_otda_mapping_colors_images.py
new file mode 100644
index 0000000..5f1e844
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping_colors_images.py
@@ -0,0 +1,174 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================================
+OT for image color adaptation with mapping estimation
+=====================================================
+
+OT for domain adaptation with image color adaptation [6] with mapping
+estimation [8].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),
+ 1853-1882.
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
+ discrete optimal transport", Neural Information Processing Systems (NIPS),
+ 2016.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import numpy as np
+from scipy import ndimage
+import matplotlib.pylab as pl
+import ot
+
+r = np.random.RandomState(42)
+
+
+def im2mat(I):
+ """Converts an image to a matrix (one pixel per row)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+def mat2im(X, shape):
+ """Converts a matrix back to an image"""
+ return X.reshape(shape)
+
+
+def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+##############################################################################
+# Generate data
+# -------------
+
+# Loading images
+I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+X1 = im2mat(I1)
+X2 = im2mat(I2)
+
+# training samples
+nb = 1000
+idx1 = r.randint(X1.shape[0], size=(nb,))
+idx2 = r.randint(X2.shape[0], size=(nb,))
+
+Xs = X1[idx1, :]
+Xt = X2[idx2, :]
+
+
+##############################################################################
+# Domain adaptation for pixel distribution transfer
+# -------------------------------------------------
+
+# EMDTransport
+ot_emd = ot.da.EMDTransport()
+ot_emd.fit(Xs=Xs, Xt=Xt)
+transp_Xs_emd = ot_emd.transform(Xs=X1)
+Image_emd = minmax(mat2im(transp_Xs_emd, I1.shape))
+
+# SinkhornTransport
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+
+ot_mapping_linear = ot.da.MappingTransport(
+ mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)
+ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+X1tl = ot_mapping_linear.transform(Xs=X1)
+Image_mapping_linear = minmax(mat2im(X1tl, I1.shape))
+
+ot_mapping_gaussian = ot.da.MappingTransport(
+ mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)
+ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping
+Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
+
+
+##############################################################################
+# Plot original images
+# --------------------
+
+pl.figure(1, figsize=(6.4, 3))
+pl.subplot(1, 2, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Plot pixel values distribution
+# ------------------------------
+
+pl.figure(2, figsize=(6.4, 5))
+
+pl.subplot(1, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 1')
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+pl.axis([0, 1, 0, 1])
+pl.xlabel('Red')
+pl.ylabel('Blue')
+pl.title('Image 2')
+pl.tight_layout()
+
+
+##############################################################################
+# Plot transformed images
+# -----------------------
+
+pl.figure(2, figsize=(10, 5))
+
+pl.subplot(2, 3, 1)
+pl.imshow(I1)
+pl.axis('off')
+pl.title('Im. 1')
+
+pl.subplot(2, 3, 4)
+pl.imshow(I2)
+pl.axis('off')
+pl.title('Im. 2')
+
+pl.subplot(2, 3, 2)
+pl.imshow(Image_emd)
+pl.axis('off')
+pl.title('EmdTransport')
+
+pl.subplot(2, 3, 5)
+pl.imshow(Image_sinkhorn)
+pl.axis('off')
+pl.title('SinkhornTransport')
+
+pl.subplot(2, 3, 3)
+pl.imshow(Image_mapping_linear)
+pl.axis('off')
+pl.title('MappingTransport (linear)')
+
+pl.subplot(2, 3, 6)
+pl.imshow(Image_mapping_gaussian)
+pl.axis('off')
+pl.title('MappingTransport (gaussian)')
+pl.tight_layout()
+
+pl.show()
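The im2mat/mat2im helpers in this example are plain reshapes between an image array and a one-pixel-per-row matrix. A quick round-trip check on a toy array (illustrative only):

```python
import numpy as np

# toy 4x5 RGB "image" with values in [0, 1]
I = np.random.RandomState(0).rand(4, 5, 3)

# one pixel per row, as in im2mat
X = I.reshape((I.shape[0] * I.shape[1], I.shape[2]))

# and back, as in mat2im
I_back = X.reshape(I.shape)

# the reshapes are exact inverses of each other
assert np.allclose(I, I_back)
```

This is why the example can fit the transport on a random subset of pixel rows (Xs, Xt) and then apply the learned mapping to the full matrix X1 before reshaping back into an image.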
diff --git a/docs/source/auto_examples/plot_otda_mapping_colors_images.rst b/docs/source/auto_examples/plot_otda_mapping_colors_images.rst
new file mode 100644
index 0000000..8394fb0
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_mapping_colors_images.rst
@@ -0,0 +1,305 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_mapping_colors_images.py:
+
+
+=====================================================
+OT for image color adaptation with mapping estimation
+=====================================================
+
+OT for domain adaptation with image color adaptation [6] with mapping
+estimation [8].
+
+[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014). Regularized
+ discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3),
+ 1853-1882.
+[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for
+ discrete optimal transport", Neural Information Processing Systems (NIPS),
+ 2016.
+
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import numpy as np
+ from scipy import ndimage
+ import matplotlib.pylab as pl
+ import ot
+
+ r = np.random.RandomState(42)
+
+
+ def im2mat(I):
+ """Converts an image to a matrix (one pixel per row)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+ def mat2im(X, shape):
+ """Converts a matrix back to an image"""
+ return X.reshape(shape)
+
+
+ def minmax(I):
+ return np.clip(I, 0, 1)
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ # Loading images
+ I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256
+ I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256
+
+
+ X1 = im2mat(I1)
+ X2 = im2mat(I2)
+
+ # training samples
+ nb = 1000
+ idx1 = r.randint(X1.shape[0], size=(nb,))
+ idx2 = r.randint(X2.shape[0], size=(nb,))
+
+ Xs = X1[idx1, :]
+ Xt = X2[idx2, :]
+
+
+
+
+
+
+
+
+Domain adaptation for pixel distribution transfer
+-------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ # EMDTransport
+ ot_emd = ot.da.EMDTransport()
+ ot_emd.fit(Xs=Xs, Xt=Xt)
+ transp_Xs_emd = ot_emd.transform(Xs=X1)
+ Image_emd = minmax(mat2im(transp_Xs_emd, I1.shape))
+
+ # SinkhornTransport
+ ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
+ transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)
+ Image_sinkhorn = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))
+
+ ot_mapping_linear = ot.da.MappingTransport(
+ mu=1e0, eta=1e-8, bias=True, max_iter=20, verbose=True)
+ ot_mapping_linear.fit(Xs=Xs, Xt=Xt)
+
+ X1tl = ot_mapping_linear.transform(Xs=X1)
+ Image_mapping_linear = minmax(mat2im(X1tl, I1.shape))
+
+ ot_mapping_gaussian = ot.da.MappingTransport(
+ mu=1e0, eta=1e-2, sigma=1, bias=False, max_iter=10, verbose=True)
+ ot_mapping_gaussian.fit(Xs=Xs, Xt=Xt)
+
+ X1tn = ot_mapping_gaussian.transform(Xs=X1) # use the estimated mapping
+ Image_mapping_gaussian = minmax(mat2im(X1tn, I1.shape))
+
+
+
+
+
+
+.. rst-class:: sphx-glr-script-out
+
+ Out::
+
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|3.680518e+02|0.000000e+00
+ 1|3.592439e+02|-2.393116e-02
+ 2|3.590632e+02|-5.030248e-04
+ 3|3.589698e+02|-2.601358e-04
+ 4|3.589118e+02|-1.614977e-04
+ 5|3.588724e+02|-1.097608e-04
+ 6|3.588436e+02|-8.035205e-05
+ 7|3.588215e+02|-6.141923e-05
+ 8|3.588042e+02|-4.832627e-05
+ 9|3.587902e+02|-3.909574e-05
+ 10|3.587786e+02|-3.225418e-05
+ 11|3.587688e+02|-2.712592e-05
+ 12|3.587605e+02|-2.314041e-05
+ 13|3.587534e+02|-1.991287e-05
+ 14|3.587471e+02|-1.744348e-05
+ 15|3.587416e+02|-1.544523e-05
+ 16|3.587367e+02|-1.364654e-05
+ 17|3.587323e+02|-1.230435e-05
+ 18|3.587284e+02|-1.093370e-05
+ 19|3.587276e+02|-2.052728e-06
+ It. |Loss |Delta loss
+ --------------------------------
+ 0|3.784758e+02|0.000000e+00
+ 1|3.646352e+02|-3.656911e-02
+ 2|3.642861e+02|-9.574714e-04
+ 3|3.641523e+02|-3.672061e-04
+ 4|3.640788e+02|-2.020990e-04
+ 5|3.640321e+02|-1.282701e-04
+ 6|3.640002e+02|-8.751240e-05
+ 7|3.639765e+02|-6.521203e-05
+ 8|3.639582e+02|-5.007767e-05
+ 9|3.639439e+02|-3.938917e-05
+ 10|3.639323e+02|-3.187865e-05
+
+
+Plot original images
+--------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(6.4, 3))
+ pl.subplot(1, 2, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Image 2')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_001.png
+ :align: center
+
+
+
+
+Plot pixel values distribution
+------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(6.4, 5))
+
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 1')
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)
+ pl.axis([0, 1, 0, 1])
+ pl.xlabel('Red')
+ pl.ylabel('Blue')
+ pl.title('Image 2')
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_003.png
+ :align: center
+
+
+
+
+Plot transformed images
+-----------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(10, 5))
+
+ pl.subplot(2, 3, 1)
+ pl.imshow(I1)
+ pl.axis('off')
+ pl.title('Im. 1')
+
+ pl.subplot(2, 3, 4)
+ pl.imshow(I2)
+ pl.axis('off')
+ pl.title('Im. 2')
+
+ pl.subplot(2, 3, 2)
+ pl.imshow(Image_emd)
+ pl.axis('off')
+ pl.title('EmdTransport')
+
+ pl.subplot(2, 3, 5)
+ pl.imshow(Image_sinkhorn)
+ pl.axis('off')
+ pl.title('SinkhornTransport')
+
+ pl.subplot(2, 3, 3)
+ pl.imshow(Image_mapping_linear)
+ pl.axis('off')
+ pl.title('MappingTransport (linear)')
+
+ pl.subplot(2, 3, 6)
+ pl.imshow(Image_mapping_gaussian)
+ pl.axis('off')
+ pl.title('MappingTransport (gaussian)')
+ pl.tight_layout()
+
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_mapping_colors_images_004.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 2 minutes 52.212 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_mapping_colors_images.py <plot_otda_mapping_colors_images.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_mapping_colors_images.ipynb <plot_otda_mapping_colors_images.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/plot_otda_semi_supervised.ipynb b/docs/source/auto_examples/plot_otda_semi_supervised.ipynb
new file mode 100644
index 0000000..783bf84
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_semi_supervised.ipynb
@@ -0,0 +1,144 @@
+{
+ "nbformat_minor": 0,
+ "nbformat": 4,
+ "cells": [
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "%matplotlib inline"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "\n# OTDA unsupervised vs semi-supervised setting\n\n\nThis example introduces semi-supervised domain adaptation in a 2D setting.\nIt illustrates the problem of semi-supervised domain adaptation and introduces\nsome optimal transport approaches to solve it.\n\nQuantities such as optimal couplings, the largest coupling coefficients and\ntransported samples are displayed in order to give a visual understanding\nof what the transport methods are doing.\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport ot"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Generate data\n-------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "n_samples_source = 150\nn_samples_target = 150\n\nXs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)\nXt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Transport source samples onto target samples\n--------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# unsupervised domain adaptation\not_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn_un.fit(Xs=Xs, Xt=Xt)\ntransp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)\n\n# semi-supervised domain adaptation\not_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)\ntransp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)\n\n# semi-supervised DA uses the available labeled target samples to modify the\n# cost matrix involved in the OT problem. The cost of transporting a source\n# sample of class A onto a target sample of class B != A is set to infinity,\n# or a very large value\n\n# note that in the present case we consider that all the target samples are\n# labeled. In real applications, some target samples might not have labels;\n# in this case the elements of yt corresponding to these samples should be\n# set to -1.\n\n# Warning: we recall that -1 cannot be used as a class label"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 1 : plots source and target samples + matrix of pairwise distance\n---------------------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(1, figsize=(10, 10))\npl.subplot(2, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Source samples')\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.xticks([])\npl.yticks([])\npl.legend(loc=0)\npl.title('Target samples')\n\npl.subplot(2, 2, 3)\npl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Cost matrix - unsupervised DA')\n\npl.subplot(2, 2, 4)\npl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Cost matrix - semi-supervised DA')\n\npl.tight_layout()\n\n# the optimal coupling in the semi-supervised DA case will exhibit a shape\n# similar to the cost matrix (block diagonal matrix)"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 2 : plots optimal couplings for the different methods\n---------------------------------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "pl.figure(2, figsize=(8, 4))\n\npl.subplot(1, 2, 1)\npl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nUnsupervised DA')\n\npl.subplot(1, 2, 2)\npl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')\npl.xticks([])\npl.yticks([])\npl.title('Optimal coupling\\nSemi-supervised DA')\n\npl.tight_layout()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "source": [
+ "Fig 3 : plot transported samples\n--------------------------------\n\n"
+ ],
+ "cell_type": "markdown",
+ "metadata": {}
+ },
+ {
+ "execution_count": null,
+ "cell_type": "code",
+ "source": [
+ "# display transported samples\npl.figure(4, figsize=(8, 4))\npl.subplot(1, 2, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nUnsupervised DA')\npl.legend(loc=0)\npl.xticks([])\npl.yticks([])\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n label='Target samples', alpha=0.5)\npl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,\n marker='+', label='Transp samples', s=30)\npl.title('Transported samples\\nSemi-supervised DA')\npl.xticks([])\npl.yticks([])\n\npl.tight_layout()\npl.show()"
+ ],
+ "outputs": [],
+ "metadata": {
+ "collapsed": false
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 2",
+ "name": "python2",
+ "language": "python"
+ },
+ "language_info": {
+ "mimetype": "text/x-python",
+ "nbconvert_exporter": "python",
+ "name": "python",
+ "file_extension": ".py",
+ "version": "2.7.12",
+ "pygments_lexer": "ipython2",
+ "codemirror_mode": {
+ "version": 2,
+ "name": "ipython"
+ }
+ }
+ }
+} \ No newline at end of file
diff --git a/docs/source/auto_examples/plot_otda_semi_supervised.py b/docs/source/auto_examples/plot_otda_semi_supervised.py
new file mode 100644
index 0000000..7963aef
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_semi_supervised.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+"""
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces semi-supervised domain adaptation in a 2D setting.
+It states the problem of semi-supervised domain adaptation and presents
+some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, the largest coupling coefficients and
+transported samples are plotted in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# Generate data
+# -------------
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
+
+
+##############################################################################
+# Transport source samples onto target samples
+# --------------------------------------------
+
+
+# unsupervised domain adaptation
+ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+# semi-supervised domain adaptation
+ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+# semi-supervised DA uses the available labeled target samples to modify the
+# cost matrix involved in the OT problem. The cost of transporting a source
+# sample of class A onto a target sample of class B != A is set to infinity
+# (in practice, a very large value)
+
+# note that in the present case we consider that all the target samples are
+# labeled. In real applications, some target samples may not have labels; the
+# elements of yt corresponding to these samples should then be set to -1
+
+# Warning: -1 is therefore reserved and cannot be used as a class label
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+# ---------------------------------------------------------------------
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - unsupervised DA')
+
+pl.subplot(2, 2, 4)
+pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - semi-supervised DA')
+
+pl.tight_layout()
+
+# the optimal coupling in the semi-supervised DA case exhibits a structure
+# similar to the cost matrix (block diagonal)
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+# ---------------------------------------------------------
+
+pl.figure(2, figsize=(8, 4))
+
+pl.subplot(1, 2, 1)
+pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nUnsupervised DA')
+
+pl.subplot(1, 2, 2)
+pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSemi-supervised DA')
+
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+# --------------------------------
+
+# display transported samples
+pl.figure(4, figsize=(8, 4))
+pl.subplot(1, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nUnsupervised DA')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSemi-supervised DA')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
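The comment block in plot_otda_semi_supervised.py explains that the labeled target samples are used to forbid class-mixing transport by inflating the corresponding cost entries. A minimal numpy sketch of that idea is below; `mask_cost` and the `limit` value are our own illustrative names for this note, not POT's internal API, and unlabeled targets are marked with -1 as in the example.

```python
import numpy as np

def mask_cost(M, ys, yt, limit=1e10):
    """Return a copy of cost matrix M where transporting a source sample of
    class A onto a labeled target sample of class B != A is penalized.

    Target entries with yt == -1 are unlabeled and are left untouched.
    """
    M = M.copy()
    for c in np.unique(ys):
        idx_s = np.where(ys == c)[0]                  # source samples of class c
        idx_t = np.where((yt != c) & (yt != -1))[0]   # labeled targets of other classes
        M[np.ix_(idx_s, idx_t)] = limit               # forbid these assignments
    return M

# toy 3x3 example: one source sample per class, third target sample unlabeled
M = np.ones((3, 3))
ys = np.array([0, 1, 2])
yt = np.array([0, 1, -1])
Mm = mask_cost(M, ys, yt)
# same-class pairs keep their cost, cross-class labeled pairs are penalized,
# and the unlabeled target column is unchanged
```

With such a cost, the optimal coupling naturally concentrates on same-class blocks, which is why the semi-supervised coupling matrix in Fig 2 looks block diagonal.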
diff --git a/docs/source/auto_examples/plot_otda_semi_supervised.rst b/docs/source/auto_examples/plot_otda_semi_supervised.rst
new file mode 100644
index 0000000..dc05ed0
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_semi_supervised.rst
@@ -0,0 +1,240 @@
+
+
+.. _sphx_glr_auto_examples_plot_otda_semi_supervised.py:
+
+
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces semi-supervised domain adaptation in a 2D setting.
+It states the problem of semi-supervised domain adaptation and presents
+some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, the largest coupling coefficients and
+transported samples are plotted in order to give a visual understanding
+of what the transport methods are doing.
+
+
+
+.. code-block:: python
+
+
+ # Authors: Remi Flamary <remi.flamary@unice.fr>
+ # Stanislas Chambon <stan.chambon@gmail.com>
+ #
+ # License: MIT License
+
+ import matplotlib.pylab as pl
+ import ot
+
+
+
+
+
+
+
+
+Generate data
+-------------
+
+
+
+.. code-block:: python
+
+
+ n_samples_source = 150
+ n_samples_target = 150
+
+ Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
+ Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
+
+
+
+
+
+
+
+
+Transport source samples onto target samples
+--------------------------------------------
+
+
+
+.. code-block:: python
+
+
+
+ # unsupervised domain adaptation
+ ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+ transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+ # semi-supervised domain adaptation
+ ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+ transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+    # semi-supervised DA uses the available labeled target samples to modify the
+    # cost matrix involved in the OT problem. The cost of transporting a source
+    # sample of class A onto a target sample of class B != A is set to infinity
+    # (in practice, a very large value)
+
+    # note that in the present case we consider that all the target samples are
+    # labeled. In real applications, some target samples may not have labels; the
+    # elements of yt corresponding to these samples should then be set to -1
+
+    # Warning: -1 is therefore reserved and cannot be used as a class label
+
+
+
+
+
+
+
+
+Fig 1 : plots source and target samples + matrix of pairwise distance
+---------------------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(1, figsize=(10, 10))
+ pl.subplot(2, 2, 1)
+ pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Source samples')
+
+ pl.subplot(2, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+ pl.xticks([])
+ pl.yticks([])
+ pl.legend(loc=0)
+ pl.title('Target samples')
+
+ pl.subplot(2, 2, 3)
+ pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Cost matrix - unsupervised DA')
+
+ pl.subplot(2, 2, 4)
+ pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+    pl.title('Cost matrix - semi-supervised DA')
+
+ pl.tight_layout()
+
+    # the optimal coupling in the semi-supervised DA case exhibits a structure
+    # similar to the cost matrix (block diagonal)
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_semi_supervised_001.png
+ :align: center
+
+
+
+
+Fig 2 : plots optimal couplings for the different methods
+---------------------------------------------------------
+
+
+
+.. code-block:: python
+
+
+ pl.figure(2, figsize=(8, 4))
+
+ pl.subplot(1, 2, 1)
+ pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nUnsupervised DA')
+
+ pl.subplot(1, 2, 2)
+ pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+ pl.xticks([])
+ pl.yticks([])
+ pl.title('Optimal coupling\nSemi-supervised DA')
+
+ pl.tight_layout()
+
+
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_semi_supervised_003.png
+ :align: center
+
+
+
+
+Fig 3 : plot transported samples
+--------------------------------
+
+
+
+.. code-block:: python
+
+
+ # display transported samples
+ pl.figure(4, figsize=(8, 4))
+ pl.subplot(1, 2, 1)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+    pl.title('Transported samples\nUnsupervised DA')
+ pl.legend(loc=0)
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.subplot(1, 2, 2)
+ pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+ pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+    pl.title('Transported samples\nSemi-supervised DA')
+ pl.xticks([])
+ pl.yticks([])
+
+ pl.tight_layout()
+ pl.show()
+
+
+
+.. image:: /auto_examples/images/sphx_glr_plot_otda_semi_supervised_006.png
+ :align: center
+
+
+
+
+**Total running time of the script:** ( 0 minutes 0.714 seconds)
+
+
+
+.. container:: sphx-glr-footer
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Python source code: plot_otda_semi_supervised.py <plot_otda_semi_supervised.py>`
+
+
+
+ .. container:: sphx-glr-download
+
+ :download:`Download Jupyter notebook: plot_otda_semi_supervised.ipynb <plot_otda_semi_supervised.ipynb>`
+
+.. rst-class:: sphx-glr-signature
+
+ `Generated by Sphinx-Gallery <http://sphinx-gallery.readthedocs.io>`_
diff --git a/docs/source/auto_examples/searchindex b/docs/source/auto_examples/searchindex
new file mode 100644
index 0000000..2cad500
--- /dev/null
+++ b/docs/source/auto_examples/searchindex
Binary files differ
diff --git a/docs/source/conf.py b/docs/source/conf.py
index ff08899..4105d87 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -50,7 +50,7 @@ sys.path.insert(0, os.path.abspath("../.."))
#needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
@@ -62,7 +62,7 @@ extensions = [
'sphinx.ext.ifconfig',
'sphinx.ext.viewcode',
'sphinx.ext.napoleon',
-# 'sphinx_gallery.gen_gallery',
+ #'sphinx_gallery.gen_gallery',
]
# Add any paths that contain templates here, relative to this directory.
@@ -261,7 +261,7 @@ latex_elements = {
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'POT.tex', u'POT Python Optimal Transport library',
- u'Rémi Flamary, Nicolas Courty', 'manual'),
+ author, 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
@@ -305,7 +305,7 @@ man_pages = [
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'POT', u'POT Python Optimal Transport library Documentation',
- author, 'POT', 'One line description of project.',
+     author, 'POT', 'Python Optimal Transport library.',
'Miscellaneous'),
]
@@ -326,9 +326,9 @@ texinfo_documents = [
intersphinx_mapping = {'https://docs.python.org/': None}
sphinx_gallery_conf = {
- 'examples_dirs': '../../examples',
+ 'examples_dirs': ['../../examples','../../examples/da'],
'gallery_dirs': 'auto_examples',
- 'mod_example_dir': '../modules/generated/',
+ 'backreferences_dir': '../modules/generated/',
'reference_url': {
'numpy': 'http://docs.scipy.org/doc/numpy-1.9.1',
'scipy': 'http://docs.scipy.org/doc/scipy-0.17.0/reference'}
diff --git a/docs/source/examples.rst b/docs/source/examples.rst
deleted file mode 100644
index f209543..0000000
--- a/docs/source/examples.rst
+++ /dev/null
@@ -1,39 +0,0 @@
-
-
-Examples
-============
-
-1D Optimal transport
----------------------
-
-.. literalinclude:: ../../examples/demo_OT_1D.py
-
-2D Optimal transport on empirical distributions
------------------------------------------------
-
-.. literalinclude:: ../../examples/demo_OT_2D_samples.py
-
-1D Wasserstein barycenter
--------------------------
-
-.. literalinclude:: ../../examples/demo_barycenter_1D.py
-
-OT with user provided regularization
-------------------------------------
-
-.. literalinclude:: ../../examples/demo_optim_OTreg.py
-
-Domain adaptation with optimal transport
-----------------------------------------
-
-.. literalinclude:: ../../examples/demo_OTDA_classes.py
-
-Color transfer in images
-------------------------
-
-.. literalinclude:: ../../examples/demo_OTDA_color_images.py
-
-OT mapping estimation for domain adaptation
--------------------------------------------
-
-.. literalinclude:: ../../examples/demo_OTDA_mapping.py
diff --git a/docs/source/readme.rst b/docs/source/readme.rst
index c1e0017..065093e 100644
--- a/docs/source/readme.rst
+++ b/docs/source/readme.rst
@@ -21,6 +21,7 @@ It provides the following solvers:
- Joint OT matrix and mapping estimation [8].
- Wasserstein Discriminant Analysis [11] (requires autograd +
pymanopt).
+- Gromov-Wasserstein distances and barycenters [12]
Some demonstrations (both in Python and Jupyter Notebook format) are
available in the examples folder.
@@ -150,27 +151,27 @@ Here is a list of the Python notebooks available
want a quick look:
- `1D optimal
- transport <https://github.com/rflamary/POT/blob/master/notebooks/Demo_1D_OT.ipynb>`__
+ transport <https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_1D.ipynb>`__
- `OT Ground
- Loss <https://github.com/rflamary/POT/blob/master/notebooks/Demo_Ground_Loss.ipynb>`__
+ Loss <https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_L1_vs_L2.ipynb>`__
- `Multiple EMD
- computation <https://github.com/rflamary/POT/blob/master/notebooks/Demo_Compute_EMD.ipynb>`__
+ computation <https://github.com/rflamary/POT/blob/master/notebooks/plot_compute_emd.ipynb>`__
- `2D optimal transport on empirical
- distributions <https://github.com/rflamary/POT/blob/master/notebooks/Demo_2D_OT_samples.ipynb>`__
+ distributions <https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_2D_samples.ipynb>`__
- `1D Wasserstein
- barycenter <https://github.com/rflamary/POT/blob/master/notebooks/Demo_1D_barycenter.ipynb>`__
+ barycenter <https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_1D.ipynb>`__
- `OT with user provided
- regularization <https://github.com/rflamary/POT/blob/master/notebooks/Demo_Optim_OTreg.ipynb>`__
+ regularization <https://github.com/rflamary/POT/blob/master/notebooks/plot_optim_OTreg.ipynb>`__
- `Domain adaptation with optimal
- transport <https://github.com/rflamary/POT/blob/master/notebooks/Demo_2D_OT_DomainAdaptation.ipynb>`__
+ transport <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_d2.ipynb>`__
- `Color transfer in
- images <https://github.com/rflamary/POT/blob/master/notebooks/Demo_Image_ColorAdaptation.ipynb>`__
+ images <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_color_images.ipynb>`__
- `OT mapping estimation for domain
- adaptation <https://github.com/rflamary/POT/blob/master/notebooks/Demo_2D_OTmapping_DomainAdaptation.ipynb>`__
+ adaptation <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping.ipynb>`__
- `OT mapping estimation for color transfer in
- images <https://github.com/rflamary/POT/blob/master/notebooks/Demo_Image_ColorAdaptation_mapping.ipynb>`__
+ images <https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping_colors_images.ipynb>`__
- `Wasserstein Discriminant
- Analysis <https://github.com/rflamary/POT/blob/master/notebooks/Demo_Wasserstein_Discriminant_Analysis.ipynb>`__
+ Analysis <https://github.com/rflamary/POT/blob/master/notebooks/plot_WDA.ipynb>`__
You can also see the notebooks with `Jupyter
nbviewer <https://nbviewer.jupyter.org/github/rflamary/POT/tree/master/notebooks/>`__.
@@ -187,6 +188,10 @@ The contributors to this library are:
- `Michael Perrot <http://perso.univ-st-etienne.fr/pem82055/>`__
(Mapping estimation)
- `Léo Gautheron <https://github.com/aje>`__ (GPU implementation)
+- `Nathalie
+ Gayraud <https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1>`__
+- `Stanislas Chambon <https://slasnista.github.io/>`__
+- `Antoine Rolet <https://arolet.github.io/>`__
This toolbox benefits a lot from open source research and we would like
to thank the following persons for providing some code (in various
@@ -196,7 +201,6 @@ languages):
in Matlab)
- `Nicolas Bonneel <http://liris.cnrs.fr/~nbonneel/>`__ ( C++ code for
EMD)
-- `Antoine Rolet <https://arolet.github.io/>`__ ( Mex file for EMD )
- `Marco Cuturi <http://marcocuturi.net/>`__ (Sinkhorn Knopp in
Matlab/Cuda)
@@ -277,6 +281,11 @@ arXiv:1607.05816.
Analysis <https://arxiv.org/pdf/1608.08063.pdf>`__. arXiv preprint
arXiv:1608.08063.
+[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
+`Gromov-Wasserstein averaging of kernel and distance
+matrices <http://proceedings.mlr.press/v48/peyre16.html>`__
+International Conference on Machine Learning (ICML). 2016.
+
.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Build Status| image:: https://travis-ci.org/rflamary/POT.svg?branch=master