diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2020-04-20 16:01:15 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2020-04-20 16:01:15 +0200 |
commit | 6ac8d405f16832e671c432d7b03ce3da38f8fedc (patch) | |
tree | 0cf1920d6f751fa14c46a791e52c16ce464efcd3 /docs/source/auto_examples/plot_otda_jcpot.rst | |
parent | 45d232f6c49bf485192953001ae81cb46d97652e (diff) |
add all pages in documentation
Diffstat (limited to 'docs/source/auto_examples/plot_otda_jcpot.rst')
-rw-r--r-- | docs/source/auto_examples/plot_otda_jcpot.rst | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/docs/source/auto_examples/plot_otda_jcpot.rst b/docs/source/auto_examples/plot_otda_jcpot.rst new file mode 100644 index 0000000..3433190 --- /dev/null +++ b/docs/source/auto_examples/plot_otda_jcpot.rst @@ -0,0 +1,336 @@ +.. only:: html + + .. note:: + :class: sphx-glr-download-link-note + + Click :ref:`here <sphx_glr_download_auto_examples_plot_otda_jcpot.py>` to download the full example code + .. rst-class:: sphx-glr-example-title + + .. _sphx_glr_auto_examples_plot_otda_jcpot.py: + + +======================== +OT for multi-source target shift +======================== + +This example introduces a target shift problem with two 2D source and 1 target domain. + + + +.. code-block:: default + + + # Authors: Remi Flamary <remi.flamary@unice.fr> + # Ievgen Redko <ievgen.redko@univ-st-etienne.fr> + # + # License: MIT License + + import pylab as pl + import numpy as np + import ot + from ot.datasets import make_data_classif + + + + + + + + +Generate data +------------- + + +.. code-block:: default + + n = 50 + sigma = 0.3 + np.random.seed(1985) + + p1 = .2 + dec1 = [0, 2] + + p2 = .9 + dec2 = [0, -2] + + pt = .4 + dect = [4, 0] + + xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1) + xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2) + xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect) + + all_Xr = [xs1, xs2] + all_Yr = [ys1, ys2] + + + + + + + + +.. code-block:: default + + + da = 1.5 + + + def plot_ax(dec, name): + pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) + pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) + pl.text(dec[0] - .5, dec[1] + 2, name) + + + + + + + + + +Fig 1 : plots source and target samples +--------------------------------------- + + +.. code-block:: default + + + pl.figure(1) + pl.clf() + plot_ax(dec1, 'Source 1') + plot_ax(dec2, 'Source 2') + plot_ax(dect, 'Target') + pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, + label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)) + pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, + label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)) + pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, + label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)) + pl.title('Data') + + pl.legend() + pl.axis('equal') + pl.axis('off') + + + + +.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_001.png + :class: sphx-glr-single-img + + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + + (-1.85, 5.85, -4.1171725099266725, 4.197384527473105) + + + +Instantiate Sinkhorn transport algorithm and fit them for all source domains +---------------------------------------------------------------------------- + + +.. code-block:: default + + ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') + + + def print_G(G, xs, ys, xt): + for i in range(G.shape[0]): + for j in range(G.shape[1]): + if G[i, j] > 5e-4: + if ys[i]: + c = 'b' + else: + c = 'r' + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2) + + + + + + + + + +Fig 2 : plot optimal couplings and transported samples +------------------------------------------------------ + + +.. code-block:: default + + pl.figure(2) + pl.clf() + plot_ax(dec1, 'Source 1') + plot_ax(dec2, 'Source 2') + plot_ax(dect, 'Target') + print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt) + print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt) + pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) + pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) + pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + + pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') + pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + + pl.title('Independent OT') + + pl.legend() + pl.axis('equal') + pl.axis('off') + + + + +.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_002.png + :class: sphx-glr-single-img + + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + + (-1.85, 5.85, -4.11901398007908, 4.201462272227509) + + + +Instantiate JCPOT adaptation algorithm and fit it +---------------------------------------------------------------------------- + + +.. code-block:: default + + otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) + otda.fit(all_Xr, all_Yr, xt) + + ws1 = otda.proportions_.dot(otda.log_['D2'][0]) + ws2 = otda.proportions_.dot(otda.log_['D2'][1]) + + pl.figure(3) + pl.clf() + plot_ax(dec1, 'Source 1') + plot_ax(dec2, 'Source 2') + plot_ax(dect, 'Target') + print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) + print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) + pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) + pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) + pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + + pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') + pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + + pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1])) + + pl.legend() + pl.axis('equal') + pl.axis('off') + + + + +.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_003.png + :class: sphx-glr-single-img + + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + + (-1.85, 5.85, -4.11901398007908, 4.201462272227509) + + + +Run oracle transport algorithm with known proportions +---------------------------------------------------------------------------- + + +.. code-block:: default + + h_res = np.array([1 - pt, pt]) + + ws1 = h_res.dot(otda.log_['D2'][0]) + ws2 = h_res.dot(otda.log_['D2'][1]) + + pl.figure(4) + pl.clf() + plot_ax(dec1, 'Source 1') + plot_ax(dec2, 'Source 2') + plot_ax(dect, 'Target') + print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) + print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) + pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) + pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) + pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + + pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') + pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + + pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1])) + + pl.legend() + pl.axis('equal') + pl.axis('off') + pl.show() + + + +.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_004.png + :class: sphx-glr-single-img + + +.. rst-class:: sphx-glr-script-out + + Out: + + .. code-block:: none + + /home/rflamary/PYTHON/POT/examples/plot_otda_jcpot.py:171: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure. + pl.show() + + + + + +.. rst-class:: sphx-glr-timing + + **Total running time of the script:** ( 0 minutes 4.725 seconds) + + +.. _sphx_glr_download_auto_examples_plot_otda_jcpot.py: + + +.. only :: html + + .. container:: sphx-glr-footer + :class: sphx-glr-footer-example + + + + .. container:: sphx-glr-download sphx-glr-download-python + + :download:`Download Python source code: plot_otda_jcpot.py <plot_otda_jcpot.py>` + + + + .. container:: sphx-glr-download sphx-glr-download-jupyter + + :download:`Download Jupyter notebook: plot_otda_jcpot.ipynb <plot_otda_jcpot.ipynb>` + + +.. only:: html + + .. rst-class:: sphx-glr-signature + + `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_ |