summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_otda_jcpot.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_otda_jcpot.ipynb')
-rw-r--r--docs/source/auto_examples/plot_otda_jcpot.ipynb173
1 files changed, 173 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_otda_jcpot.ipynb b/docs/source/auto_examples/plot_otda_jcpot.ipynb
new file mode 100644
index 0000000..a81d47a
--- /dev/null
+++ b/docs/source/auto_examples/plot_otda_jcpot.ipynb
@@ -0,0 +1,173 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n# OT for multi-source target shift\n\n\nThis example introduces a target shift problem with two 2D source and 1 target domain.\n\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "# Authors: Remi Flamary <remi.flamary@unice.fr>\n# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>\n#\n# License: MIT License\n\nimport pylab as pl\nimport numpy as np\nimport ot\nfrom ot.datasets import make_data_classif"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Generate data\n-------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "n = 50\nsigma = 0.3\nnp.random.seed(1985)\n\np1 = .2\ndec1 = [0, 2]\n\np2 = .9\ndec2 = [0, -2]\n\npt = .4\ndect = [4, 0]\n\nxs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)\nxs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)\nxt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)\n\nall_Xr = [xs1, xs2]\nall_Yr = [ys1, ys2]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "da = 1.5\n\n\ndef plot_ax(dec, name):\n pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)\n pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)\n pl.text(dec[0] - .5, dec[1] + 2, name)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 1 : plots source and target samples\n---------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(1)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,\n label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,\n label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,\n label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))\npl.title('Data')\n\npl.legend()\npl.axis('equal')\npl.axis('off')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate Sinkhorn transport algorithm and fit them for all source domains\n----------------------------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')\n\n\ndef print_G(G, xs, ys, xt):\n for i in range(G.shape[0]):\n for j in range(G.shape[1]):\n if G[i, j] > 5e-4:\n if ys[i]:\n c = 'b'\n else:\n c = 'r'\n pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Fig 2 : plot optimal couplings and transported samples\n------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "pl.figure(2)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\nprint_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)\nprint_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)\n\npl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')\npl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')\n\npl.title('Independent OT')\n\npl.legend()\npl.axis('equal')\npl.axis('off')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Instantiate JCPOT adaptation algorithm and fit it\n----------------------------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)\notda.fit(all_Xr, all_Yr, xt)\n\nws1 = otda.proportions_.dot(otda.log_['D2'][0])\nws2 = otda.proportions_.dot(otda.log_['D2'][1])\n\npl.figure(3)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\nprint_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)\nprint_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)\n\npl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')\npl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')\n\npl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))\n\npl.legend()\npl.axis('equal')\npl.axis('off')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Run oracle transport algorithm with known proportions\n----------------------------------------------------------------------------\n\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ },
+ "outputs": [],
+ "source": [
+ "h_res = np.array([1 - pt, pt])\n\nws1 = h_res.dot(otda.log_['D2'][0])\nws2 = h_res.dot(otda.log_['D2'][1])\n\npl.figure(4)\npl.clf()\nplot_ax(dec1, 'Source 1')\nplot_ax(dec2, 'Source 2')\nplot_ax(dect, 'Target')\nprint_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)\nprint_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)\npl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)\npl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)\npl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)\n\npl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')\npl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')\n\npl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))\n\npl.legend()\npl.axis('equal')\npl.axis('off')\npl.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+} \ No newline at end of file