summaryrefslogtreecommitdiff
path: root/docs/source/auto_examples/plot_otda_jcpot.rst
diff options
context:
space:
mode:
Diffstat (limited to 'docs/source/auto_examples/plot_otda_jcpot.rst')
-rw-r--r--docs/source/auto_examples/plot_otda_jcpot.rst336
1 files changed, 0 insertions, 336 deletions
diff --git a/docs/source/auto_examples/plot_otda_jcpot.rst b/docs/source/auto_examples/plot_otda_jcpot.rst
deleted file mode 100644
index 3433190..0000000
--- a/docs/source/auto_examples/plot_otda_jcpot.rst
+++ /dev/null
@@ -1,336 +0,0 @@
-.. only:: html
-
- .. note::
- :class: sphx-glr-download-link-note
-
- Click :ref:`here <sphx_glr_download_auto_examples_plot_otda_jcpot.py>` to download the full example code
- .. rst-class:: sphx-glr-example-title
-
- .. _sphx_glr_auto_examples_plot_otda_jcpot.py:
-
-
-========================
-OT for multi-source target shift
-========================
-
-This example introduces a target shift problem with two 2D source and 1 target domain.
-
-
-
-.. code-block:: default
-
-
- # Authors: Remi Flamary <remi.flamary@unice.fr>
- # Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
- #
- # License: MIT License
-
- import pylab as pl
- import numpy as np
- import ot
- from ot.datasets import make_data_classif
-
-
-
-
-
-
-
-
-Generate data
--------------
-
-
-.. code-block:: default
-
- n = 50
- sigma = 0.3
- np.random.seed(1985)
-
- p1 = .2
- dec1 = [0, 2]
-
- p2 = .9
- dec2 = [0, -2]
-
- pt = .4
- dect = [4, 0]
-
- xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)
- xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)
- xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)
-
- all_Xr = [xs1, xs2]
- all_Yr = [ys1, ys2]
-
-
-
-
-
-
-
-
-.. code-block:: default
-
-
- da = 1.5
-
-
- def plot_ax(dec, name):
- pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)
- pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)
- pl.text(dec[0] - .5, dec[1] + 2, name)
-
-
-
-
-
-
-
-
-
-Fig 1 : plots source and target samples
----------------------------------------
-
-
-.. code-block:: default
-
-
- pl.figure(1)
- pl.clf()
- plot_ax(dec1, 'Source 1')
- plot_ax(dec2, 'Source 2')
- plot_ax(dect, 'Target')
- pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,
- label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))
- pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,
- label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))
- pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,
- label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))
- pl.title('Data')
-
- pl.legend()
- pl.axis('equal')
- pl.axis('off')
-
-
-
-
-.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_001.png
- :class: sphx-glr-single-img
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out:
-
- .. code-block:: none
-
-
- (-1.85, 5.85, -4.1171725099266725, 4.197384527473105)
-
-
-
-Instantiate Sinkhorn transport algorithm and fit them for all source domains
-----------------------------------------------------------------------------
-
-
-.. code-block:: default
-
- ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')
-
-
- def print_G(G, xs, ys, xt):
- for i in range(G.shape[0]):
- for j in range(G.shape[1]):
- if G[i, j] > 5e-4:
- if ys[i]:
- c = 'b'
- else:
- c = 'r'
- pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2)
-
-
-
-
-
-
-
-
-
-Fig 2 : plot optimal couplings and transported samples
-------------------------------------------------------
-
-
-.. code-block:: default
-
- pl.figure(2)
- pl.clf()
- plot_ax(dec1, 'Source 1')
- plot_ax(dec2, 'Source 2')
- plot_ax(dect, 'Target')
- print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)
- print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)
- pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
- pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
- pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
-
- pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
- pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
-
- pl.title('Independent OT')
-
- pl.legend()
- pl.axis('equal')
- pl.axis('off')
-
-
-
-
-.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_002.png
- :class: sphx-glr-single-img
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out:
-
- .. code-block:: none
-
-
- (-1.85, 5.85, -4.11901398007908, 4.201462272227509)
-
-
-
-Instantiate JCPOT adaptation algorithm and fit it
-----------------------------------------------------------------------------
-
-
-.. code-block:: default
-
- otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
- otda.fit(all_Xr, all_Yr, xt)
-
- ws1 = otda.proportions_.dot(otda.log_['D2'][0])
- ws2 = otda.proportions_.dot(otda.log_['D2'][1])
-
- pl.figure(3)
- pl.clf()
- plot_ax(dec1, 'Source 1')
- plot_ax(dec2, 'Source 2')
- plot_ax(dect, 'Target')
- print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
- print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
- pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
- pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
- pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
-
- pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
- pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
-
- pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))
-
- pl.legend()
- pl.axis('equal')
- pl.axis('off')
-
-
-
-
-.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_003.png
- :class: sphx-glr-single-img
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out:
-
- .. code-block:: none
-
-
- (-1.85, 5.85, -4.11901398007908, 4.201462272227509)
-
-
-
-Run oracle transport algorithm with known proportions
-----------------------------------------------------------------------------
-
-
-.. code-block:: default
-
- h_res = np.array([1 - pt, pt])
-
- ws1 = h_res.dot(otda.log_['D2'][0])
- ws2 = h_res.dot(otda.log_['D2'][1])
-
- pl.figure(4)
- pl.clf()
- plot_ax(dec1, 'Source 1')
- plot_ax(dec2, 'Source 2')
- plot_ax(dect, 'Target')
- print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt)
- print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt)
- pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
- pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
- pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
-
- pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
- pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
-
- pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))
-
- pl.legend()
- pl.axis('equal')
- pl.axis('off')
- pl.show()
-
-
-
-.. image:: /auto_examples/images/sphx_glr_plot_otda_jcpot_004.png
- :class: sphx-glr-single-img
-
-
-.. rst-class:: sphx-glr-script-out
-
- Out:
-
- .. code-block:: none
-
- /home/rflamary/PYTHON/POT/examples/plot_otda_jcpot.py:171: UserWarning: Matplotlib is currently using agg, which is a non-GUI backend, so cannot show the figure.
- pl.show()
-
-
-
-
-
-.. rst-class:: sphx-glr-timing
-
- **Total running time of the script:** ( 0 minutes 4.725 seconds)
-
-
-.. _sphx_glr_download_auto_examples_plot_otda_jcpot.py:
-
-
-.. only :: html
-
- .. container:: sphx-glr-footer
- :class: sphx-glr-footer-example
-
-
-
- .. container:: sphx-glr-download sphx-glr-download-python
-
- :download:`Download Python source code: plot_otda_jcpot.py <plot_otda_jcpot.py>`
-
-
-
- .. container:: sphx-glr-download sphx-glr-download-jupyter
-
- :download:`Download Jupyter notebook: plot_otda_jcpot.ipynb <plot_otda_jcpot.ipynb>`
-
-
-.. only:: html
-
- .. rst-class:: sphx-glr-signature
-
- `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_