summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_otda_mapping.ipynb
blob: 898466d5d9fffacafb0a03361065c884c887f9b0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# OT mapping estimation for domain adaptation\n\n\nThis example presents how to use MappingTransport to estimate at the same\ntime both the coupling transport and approximate the transport map with either\na linear or a kernelized mapping as introduced in [8].\n\n[8] M. Perrot, N. Courty, R. Flamary, A. Habrard,\n    \"Mapping estimation for discrete optimal transport\",\n    Neural Information Processing Systems (NIPS), 2016.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: Remi Flamary <remi.flamary@unice.fr>\n#          Stanislas Chambon <stan.chambon@gmail.com>\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate data\n-------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_source_samples = 100\nn_target_samples = 100\ntheta = 2 * np.pi / 20\nnoise_level = 0.1\n\nXs, ys = ot.datasets.make_data_classif(\n    'gaussrot', n_source_samples, nz=noise_level)\nXs_new, _ = ot.datasets.make_data_classif(\n    'gaussrot', n_source_samples, nz=noise_level)\nXt, yt = ot.datasets.make_data_classif(\n    'gaussrot', n_target_samples, theta=theta, nz=noise_level)\n\n# one of the target mode changes its variance (no linear mapping)\nXt[yt == 2] *= 3\nXt = Xt + 4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot data\n---------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pl.figure(1, (10, 5))\npl.clf()\npl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')\npl.legend(loc=0)\npl.title('Source and target distributions')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Instantiate the different transport algorithms and fit them\n-----------------------------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# MappingTransport with linear kernel\not_mapping_linear = ot.da.MappingTransport(\n    kernel=\"linear\", mu=1e0, eta=1e-8, bias=True,\n    max_iter=20, verbose=True)\n\not_mapping_linear.fit(Xs=Xs, Xt=Xt)\n\n# for original source samples, transform applies barycentric mapping\ntransp_Xs_linear = ot_mapping_linear.transform(Xs=Xs)\n\n# for out of source samples, transform applies the linear mapping\ntransp_Xs_linear_new = ot_mapping_linear.transform(Xs=Xs_new)\n\n\n# MappingTransport with gaussian kernel\not_mapping_gaussian = ot.da.MappingTransport(\n    kernel=\"gaussian\", eta=1e-5, mu=1e-1, bias=True, sigma=1,\n    max_iter=10, verbose=True)\not_mapping_gaussian.fit(Xs=Xs, Xt=Xt)\n\n# for original source samples, transform applies barycentric mapping\ntransp_Xs_gaussian = ot_mapping_gaussian.transform(Xs=Xs)\n\n# for out of source samples, transform applies the gaussian mapping\ntransp_Xs_gaussian_new = ot_mapping_gaussian.transform(Xs=Xs_new)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Plot transported samples\n------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pl.figure(2)\npl.clf()\npl.subplot(2, 2, 1)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n           label='Target samples', alpha=.2)\npl.scatter(transp_Xs_linear[:, 0], transp_Xs_linear[:, 1], c=ys, marker='+',\n           label='Mapped source samples')\npl.title(\"Bary. mapping (linear)\")\npl.legend(loc=0)\n\npl.subplot(2, 2, 2)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n           label='Target samples', alpha=.2)\npl.scatter(transp_Xs_linear_new[:, 0], transp_Xs_linear_new[:, 1],\n           c=ys, marker='+', label='Learned mapping')\npl.title(\"Estim. mapping (linear)\")\n\npl.subplot(2, 2, 3)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n           label='Target samples', alpha=.2)\npl.scatter(transp_Xs_gaussian[:, 0], transp_Xs_gaussian[:, 1], c=ys,\n           marker='+', label='barycentric mapping')\npl.title(\"Bary. mapping (kernel)\")\n\npl.subplot(2, 2, 4)\npl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',\n           label='Target samples', alpha=.2)\npl.scatter(transp_Xs_gaussian_new[:, 0], transp_Xs_gaussian_new[:, 1], c=ys,\n           marker='+', label='Learned mapping')\npl.title(\"Estim. mapping (kernel)\")\npl.tight_layout()\n\npl.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}