{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n# OT for image color adaptation\n\n\nThis example presents a way of transferring colors between two images\nwith Optimal Transport as introduced in [6]\n\n[6] Ferradans, S., Papadakis, N., Peyre, G., & Aujol, J. F. (2014).\nRegularized discrete optimal transport.\nSIAM Journal on Imaging Sciences, 7(3), 1853-1882.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: Remi Flamary \n# Stanislas Chambon \n#\n# License: MIT License\n\nimport numpy as np\nfrom scipy import ndimage\nimport matplotlib.pylab as pl\nimport ot\n\n\nr = np.random.RandomState(42)\n\n\ndef im2mat(I):\n \"\"\"Converts an image to matrix (one pixel per line)\"\"\"\n return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))\n\n\ndef mat2im(X, shape):\n \"\"\"Converts back a matrix to an image\"\"\"\n return X.reshape(shape)\n\n\ndef minmax(I):\n return np.clip(I, 0, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate data\n-------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Loading images\nI1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256\nI2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256\n\nX1 = im2mat(I1)\nX2 = im2mat(I2)\n\n# training samples\nnb = 1000\nidx1 = r.randint(X1.shape[0], size=(nb,))\nidx2 = r.randint(X2.shape[0], size=(nb,))\n\nXs = X1[idx1, :]\nXt = X2[idx2, :]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot original image\n-------------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(1, figsize=(6.4, 3))\n\npl.subplot(1, 2, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scatter plot of colors\n----------------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2, figsize=(6.4, 3))\n\npl.subplot(1, 2, 1)\npl.scatter(Xs[:, 0], Xs[:, 2], c=Xs)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 1')\n\npl.subplot(1, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 2], c=Xt)\npl.axis([0, 1, 0, 1])\npl.xlabel('Red')\npl.ylabel('Blue')\npl.title('Image 2')\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# EMDTransport\not_emd = ot.da.EMDTransport()\not_emd.fit(Xs=Xs, Xt=Xt)\n\n# SinkhornTransport\not_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1)\not_sinkhorn.fit(Xs=Xs, Xt=Xt)\n\n# prediction between images (using out of sample prediction as in [6])\ntransp_Xs_emd = ot_emd.transform(Xs=X1)\ntransp_Xt_emd = ot_emd.inverse_transform(Xt=X2)\n\ntransp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=X1)\ntransp_Xt_sinkhorn = ot_sinkhorn.inverse_transform(Xt=X2)\n\nI1t = minmax(mat2im(transp_Xs_emd, I1.shape))\nI2t = minmax(mat2im(transp_Xt_emd, I2.shape))\n\nI1te = minmax(mat2im(transp_Xs_sinkhorn, I1.shape))\nI2te = minmax(mat2im(transp_Xt_sinkhorn, I2.shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot new images\n---------------\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(3, figsize=(8, 4))\n\npl.subplot(2, 3, 1)\npl.imshow(I1)\npl.axis('off')\npl.title('Image 1')\n\npl.subplot(2, 3, 2)\npl.imshow(I1t)\npl.axis('off')\npl.title('Image 1 Adapt')\n\npl.subplot(2, 3, 3)\npl.imshow(I1te)\npl.axis('off')\npl.title('Image 1 Adapt (reg)')\n\npl.subplot(2, 3, 4)\npl.imshow(I2)\npl.axis('off')\npl.title('Image 2')\n\npl.subplot(2, 3, 5)\npl.imshow(I2t)\npl.axis('off')\npl.title('Image 2 Adapt')\n\npl.subplot(2, 3, 6)\npl.imshow(I2te)\npl.axis('off')\npl.title('Image 2 Adapt (reg)')\npl.tight_layout()\n\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 0 }