{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Stochastic examples\n\n\nThis example is designed to show how to use the stochatic optimization\nalgorithms for descrete and semicontinous measures from the POT library.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Kilian Fatras \n#\n# License: MIT License\n\nimport matplotlib.pylab as pl\nimport numpy as np\nimport ot\nimport ot.plot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "COMPUTE TRANSPORTATION MATRIX FOR SEMI-DUAL PROBLEM\n############################################################################\n############################################################################\n DISCRETE CASE:\n\n Sample two discrete measures for the discrete case\n ---------------------------------------------\n\n Define 2 discrete measures a and b, the points where are defined the source\n and the target measures and finally the cost matrix c.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 1000\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the \"SAG\" method to find the transportation matrix in the discrete case\n---------------------------------------------\n\nDefine the method \"SAG\", call ot.solve_semi_dual_entropic and plot the\nresults.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "method = \"SAG\"\nsag_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,\n numItermax)\nprint(sag_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SEMICONTINOUS CASE:\n\nSample one general measure a, one discrete measures b for the semicontinous\ncase\n---------------------------------------------\n\nDefine one general measure a, one discrete measures b, the points where\nare defined the source and the target measures and finally the cost matrix c.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 1000\nlog = True\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the \"ASGD\" method to find the transportation matrix in the semicontinous\ncase\n---------------------------------------------\n\nDefine the method \"ASGD\", call ot.solve_semi_dual_entropic and plot the\nresults.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "method = \"ASGD\"\nasgd_pi, log_asgd = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method,\n numItermax, log=log)\nprint(log_asgd['alpha'], log_asgd['beta'])\nprint(asgd_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare the results with the Sinkhorn algorithm\n---------------------------------------------\n\nCall the Sinkhorn algorithm from POT\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sinkhorn_pi = ot.sinkhorn(a, b, M, reg)\nprint(sinkhorn_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PLOT TRANSPORTATION MATRIX\n#############################################################################\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot SAG results\n----------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sag_pi, 'semi-dual : OT matrix SAG')\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot ASGD results\n-----------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, asgd_pi, 'semi-dual : OT matrix ASGD')\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot Sinkhorn results\n---------------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "COMPUTE TRANSPORTATION MATRIX FOR DUAL PROBLEM\n############################################################################\n############################################################################\n SEMICONTINOUS CASE:\n\n Sample one general measure a, one discrete measures b for the semicontinous\n case\n ---------------------------------------------\n\n Define one general measure a, one discrete measures b, the points where\n are defined the source and the target measures and finally the cost matrix c.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_source = 7\nn_target = 4\nreg = 1\nnumItermax = 100000\nlr = 0.1\nbatch_size = 3\nlog = True\n\na = ot.utils.unif(n_source)\nb = ot.utils.unif(n_target)\n\nrng = np.random.RandomState(0)\nX_source = rng.randn(n_source, 2)\nY_target = rng.randn(n_target, 2)\nM = ot.dist(X_source, Y_target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Call the \"SGD\" dual method to find the transportation matrix in the\nsemicontinous case\n---------------------------------------------\n\nCall ot.solve_dual_entropic and plot the results.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sgd_dual_pi, log_sgd = ot.stochastic.solve_dual_entropic(a, b, M, reg,\n batch_size, numItermax,\n lr, log=log)\nprint(log_sgd['alpha'], log_sgd['beta'])\nprint(sgd_dual_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare the results with the Sinkhorn algorithm\n---------------------------------------------\n\nCall the Sinkhorn algorithm from POT\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sinkhorn_pi = ot.sinkhorn(a, b, M, reg)\nprint(sinkhorn_pi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot SGD results\n-----------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sgd_dual_pi, 'dual : OT matrix SGD')\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot Sinkhorn results\n---------------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, figsize=(5, 5))\not.plot.plot1D_mat(a, b, sinkhorn_pi, 'OT matrix Sinkhorn')\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 0 }