diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2017-08-29 14:12:29 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-29 14:12:29 +0200 |
commit | a2ec6e55e458c719484e86a4e6a6e764c2e38dc8 (patch) | |
tree | 1b973cb5314af46a28060c477840903bf3fbf4ac /examples/da/plot_otda_d2.py | |
parent | 7638d019b43e52d17600cac653939e7cd807478c (diff) | |
parent | 65de6fc9add57b95b8968e1e75fe1af342f81d01 (diff) |
Merge pull request #22 from Slasnista/domain_adaptation
Fixes #17
Diffstat (limited to 'examples/da/plot_otda_d2.py')
-rw-r--r-- | examples/da/plot_otda_d2.py | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/examples/da/plot_otda_d2.py b/examples/da/plot_otda_d2.py new file mode 100644 index 0000000..3daa0a6 --- /dev/null +++ b/examples/da/plot_otda_d2.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +""" +============================== +OT for empirical distributions +============================== + +This example introduces a domain adaptation in a 2D setting. It explicits +the problem of domain adaptation and introduces some optimal transport +approaches to solve it. + +Quantities such as optimal couplings, greater coupling coefficients and +transported samples are represented in order to give a visual understanding +of what the transport methods are doing. +""" + +# Authors: Remi Flamary <remi.flamary@unice.fr> +# Stanislas Chambon <stan.chambon@gmail.com> +# +# License: MIT License + +import matplotlib.pylab as pl +import ot + + +############################################################################## +# generate data +############################################################################## + +n_samples_source = 150 +n_samples_target = 150 + +Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source) +Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target) + +# Cost matrix +M = ot.dist(Xs, Xt, metric='sqeuclidean') + + +############################################################################## +# Instantiate the different transport algorithms and fit them +############################################################################## + +# EMD Transport +ot_emd = ot.da.EMDTransport() +ot_emd.fit(Xs=Xs, Xt=Xt) + +# Sinkhorn Transport +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1) +ot_sinkhorn.fit(Xs=Xs, Xt=Xt) + +# Sinkhorn Transport with Group lasso regularization +ot_lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1e-1, reg_cl=1e0) +ot_lpl1.fit(Xs=Xs, ys=ys, Xt=Xt) + +# transport source samples onto target samples +transp_Xs_emd = ot_emd.transform(Xs=Xs) +transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) +transp_Xs_lpl1 = ot_lpl1.transform(Xs=Xs) + + +############################################################################## +# Fig 1 : plots source and target samples + matrix of pairwise distance +############################################################################## + +pl.figure(1, figsize=(10, 10)) +pl.subplot(2, 2, 1) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Source samples') + +pl.subplot(2, 2, 2) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Target samples') + +pl.subplot(2, 2, 3) +pl.imshow(M, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Matrix of pairwise distances') +pl.tight_layout() + + +############################################################################## +# Fig 2 : plots optimal couplings for the different methods +############################################################################## + +pl.figure(2, figsize=(10, 6)) + +pl.subplot(2, 3, 1) +pl.imshow(ot_emd.coupling_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nEMDTransport') + +pl.subplot(2, 3, 2) +pl.imshow(ot_sinkhorn.coupling_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nSinkhornTransport') + +pl.subplot(2, 3, 3) +pl.imshow(ot_lpl1.coupling_, interpolation='nearest') +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nSinkhornLpl1Transport') + +pl.subplot(2, 3, 4) +ot.plot.plot2D_samples_mat(Xs, Xt, ot_emd.coupling_, c=[.5, .5, 1]) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.title('Main coupling coefficients\nEMDTransport') + +pl.subplot(2, 3, 5) +ot.plot.plot2D_samples_mat(Xs, Xt, ot_sinkhorn.coupling_, c=[.5, .5, 1]) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.title('Main coupling coefficients\nSinkhornTransport') + +pl.subplot(2, 3, 6) +ot.plot.plot2D_samples_mat(Xs, Xt, ot_lpl1.coupling_, c=[.5, .5, 1]) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.title('Main coupling coefficients\nSinkhornLpl1Transport') +pl.tight_layout() + + +############################################################################## +# Fig 3 : plot transported samples +############################################################################## + +# display transported samples +pl.figure(4, figsize=(10, 4)) +pl.subplot(1, 3, 1) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.5) +pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.title('Transported samples\nEmdTransport') +pl.legend(loc=0) +pl.xticks([]) +pl.yticks([]) + +pl.subplot(1, 3, 2) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.5) +pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.title('Transported samples\nSinkhornTransport') +pl.xticks([]) +pl.yticks([]) + +pl.subplot(1, 3, 3) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.5) +pl.scatter(transp_Xs_lpl1[:, 0], transp_Xs_lpl1[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.title('Transported samples\nSinkhornLpl1Transport') +pl.xticks([]) +pl.yticks([]) + +pl.tight_layout() +pl.show() |