summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_OT_1D_smooth.ipynb
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2020-04-20 15:19:09 +0200
committerRémi Flamary <remi.flamary@gmail.com>2020-04-20 15:19:09 +0200
commite65606ae498bd611f6a994868c2a66dfbea403cd (patch)
treeb9b43dcaf8499b6d57b806ce04350fb6b792537f /docs/source/auto_examples/plot_OT_1D_smooth.ipynb
parent8acaf262baa04a4d2bdd9c774c45c5bb2fb2d12a (diff)
big update examples
Diffstat (limited to 'docs/source/auto_examples/plot_OT_1D_smooth.ipynb')
-rw-r--r--docs/source/auto_examples/plot_OT_1D_smooth.ipynb34
1 files changed, 28 insertions, 6 deletions
diff --git a/docs/source/auto_examples/plot_OT_1D_smooth.ipynb b/docs/source/auto_examples/plot_OT_1D_smooth.ipynb
index d523f6a..493e6bb 100644
--- a/docs/source/auto_examples/plot_OT_1D_smooth.ipynb
+++ b/docs/source/auto_examples/plot_OT_1D_smooth.ipynb
@@ -44,7 +44,7 @@
},
"outputs": [],
"source": [
- "#%% parameters\n\nn = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
+ "n = 100 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5) # m= mean, s= std\nb = gauss(n, m=60, s=10)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()"
]
},
{
@@ -62,7 +62,18 @@
},
"outputs": [],
"source": [
- "#%% plot the distributions\n\npl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()\n\n#%% plot distributions and loss matrix\n\npl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
+ "pl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, 'b', label='Source distribution')\npl.plot(x, b, 'r', label='Target distribution')\npl.legend()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, 'Cost matrix M')"
]
},
{
@@ -80,7 +91,7 @@
},
"outputs": [],
"source": [
- "#%% EMD\n\nG0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
+ "G0 = ot.emd(a, b, M)\n\npl.figure(3, figsize=(5, 5))\not.plot.plot1D_mat(a, b, G0, 'OT matrix G0')"
]
},
{
@@ -98,7 +109,7 @@
},
"outputs": [],
"source": [
- "#%% Sinkhorn\n\nlambd = 2e-3\nGs = ot.sinkhorn(a, b, M, lambd, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')\n\npl.show()"
+ "lambd = 2e-3\nGs = ot.sinkhorn(a, b, M, lambd, verbose=True)\n\npl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')\n\npl.show()"
]
},
{
@@ -116,7 +127,18 @@
},
"outputs": [],
"source": [
- "#%% Smooth OT with KL regularization\n\nlambd = 2e-3\nGsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')\n\npl.figure(5, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')\n\npl.show()\n\n\n#%% Smooth OT with KL regularization\n\nlambd = 1e-1\nGsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')\n\npl.figure(6, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')\n\npl.show()"
+ "lambd = 2e-3\nGsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')\n\npl.figure(5, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')\n\npl.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "lambd = 1e-1\nGsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')\n\npl.figure(6, figsize=(5, 5))\not.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')\n\npl.show()"
]
}
],
@@ -136,7 +158,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.5"
+ "version": "3.6.9"
}
},
"nbformat": 4,