diff options
Diffstat (limited to 'docs/source/auto_examples/plot_otda_jcpot.rst')
-rw-r--r-- | docs/source/auto_examples/plot_otda_jcpot.rst | 336 |
1 files changed, 0 insertions, 336 deletions
diff --git a/docs/source/auto_examples/plot_otda_jcpot.rst b/docs/source/auto_examples/plot_otda_jcpot.rst deleted file mode 100644 index 3433190..0000000 --- a/docs/source/auto_examples/plot_otda_jcpot.rst +++ /dev/null @@ -1,336 +0,0 @@ -.. only:: html - - .. note:: - :class: sphx-glr-download-link-note - - Click :ref:`here <sphx_glr_download_auto_examples_plot_otda_jcpot.py>` to download the full example code - .. rst-class:: sphx-glr-example-title - - .. _sphx_glr_auto_examples_plot_otda_jcpot.py: - - -======================== -OT for multi-source target shift -======================== - -This example introduces a target shift problem with two 2D source and 1 target domain. - - - -.. code-block:: default - - - # Authors: Remi Flamary <remi.flamary@unice.fr> - # Ievgen Redko <ievgen.redko@univ-st-etienne.fr> - # - # License: MIT License - - import pylab as pl - import numpy as np - import ot - from ot.datasets import make_data_classif - - - - - - - - -Generate data -------------- - - -.. code-block:: default - - n = 50 - sigma = 0.3 - np.random.seed(1985) - - p1 = .2 - dec1 = [0, 2] - - p2 = .9 - dec2 = [0, -2] - - pt = .4 - dect = [4, 0] - - xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1) - xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2) - xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect) - - all_Xr = [xs1, xs2] - all_Yr = [ys1, ys2] - - - - - - - - -.. code-block:: default - - - da = 1.5 - - - def plot_ax(dec, name): - pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) - pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) - pl.text(dec[0] - .5, dec[1] + 2, name) - - - - - - - - - -Fig 1 : plots source and target samples ---------------------------------------- - - -.. code-block:: default - - - pl.figure(1) - pl.clf() - plot_ax(dec1, 'Source 1') - plot_ax(dec2, 'Source 2') - plot_ax(dect, 'Target') - pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, - label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)) - pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, - label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)) - pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, - label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)) - pl.title('Data') - - pl.legend() - pl.axis('equal') - pl.axis('off') - - - - -.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_001.png - :class: sphx-glr-single-img - - -.. rst-class:: sphx-glr-script-out - - Out: - - .. code-block:: none - - - (-1.85, 5.85, -4.1171725099266725, 4.197384527473105) - - - -Instantiate Sinkhorn transport algorithm and fit them for all source domains ----------------------------------------------------------------------------- - - -.. code-block:: default - - ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') - - - def print_G(G, xs, ys, xt): - for i in range(G.shape[0]): - for j in range(G.shape[1]): - if G[i, j] > 5e-4: - if ys[i]: - c = 'b' - else: - c = 'r' - pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2) - - - - - - - - - -Fig 2 : plot optimal couplings and transported samples ------------------------------------------------------- - - -.. code-block:: default - - pl.figure(2) - pl.clf() - plot_ax(dec1, 'Source 1') - plot_ax(dec2, 'Source 2') - plot_ax(dect, 'Target') - print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt) - print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt) - pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) - pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) - pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) - - pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') - pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') - - pl.title('Independent OT') - - pl.legend() - pl.axis('equal') - pl.axis('off') - - - - -.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_002.png - :class: sphx-glr-single-img - - -.. rst-class:: sphx-glr-script-out - - Out: - - .. code-block:: none - - - (-1.85, 5.85, -4.11901398007908, 4.201462272227509) - - - -Instantiate JCPOT adaptation algorithm and fit it ----------------------------------------------------------------------------- - - -.. code-block:: default - - otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) - otda.fit(all_Xr, all_Yr, xt) - - ws1 = otda.proportions_.dot(otda.log_['D2'][0]) - ws2 = otda.proportions_.dot(otda.log_['D2'][1]) - - pl.figure(3) - pl.clf() - plot_ax(dec1, 'Source 1') - plot_ax(dec2, 'Source 2') - plot_ax(dect, 'Target') - print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) - print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) - pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) - pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) - pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) - - pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') - pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') - - pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1])) - - pl.legend() - pl.axis('equal') - pl.axis('off') - - - - -.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_003.png - :class: sphx-glr-single-img - - -.. rst-class:: sphx-glr-script-out - - Out: - - .. code-block:: none - - - (-1.85, 5.85, -4.11901398007908, 4.201462272227509) - - - -Run oracle transport algorithm with known proportions ----------------------------------------------------------------------------- - - -.. code-block:: default - - h_res = np.array([1 - pt, pt]) - - ws1 = h_res.dot(otda.log_['D2'][0]) - ws2 = h_res.dot(otda.log_['D2'][1]) - - pl.figure(4) - pl.clf() - plot_ax(dec1, 'Source 1') - plot_ax(dec2, 'Source 2') - plot_ax(dect, 'Target') - print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) - print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) - pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) - pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) - pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) - - pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') - pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') - - pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1])) - - pl.legend() - pl.axis('equal') - pl.axis('off') - pl.show() - - - -.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_004.png - :class: sphx-glr-single-img - - -.. rst-class:: sphx-glr-script-out - - Out: - - .. code-block:: none - - /home/rflamary/PYTHON/POT/examples/plot_otda_jcpot.py:171: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure. - pl.show() - - - - - -.. rst-class:: sphx-glr-timing - - **Total running time of the script:** ( 0 minutes 4.725 seconds) - - -.. _sphx_glr_download_auto_examples_plot_otda_jcpot.py: - - -.. only :: html - - .. container:: sphx-glr-footer - :class: sphx-glr-footer-example - - - - .. container:: sphx-glr-download sphx-glr-download-python - - :download:`Download Python source code: plot_otda_jcpot.py <plot_otda_jcpot.py>` - - - - .. container:: sphx-glr-download sphx-glr-download-jupyter - - :download:`Download Jupyter notebook: plot_otda_jcpot.ipynb <plot_otda_jcpot.ipynb>` - - -.. only:: html - - .. rst-class:: sphx-glr-signature - - `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_ |