From f79f4831e9d70217ceab2758733abe62dc93208b Mon Sep 17 00:00:00 2001 From: Slasnista Date: Mon, 28 Aug 2017 11:03:28 +0200 Subject: handling input arguments in fit, transform... methods + remove old examples --- examples/plot_OTDA_classes.py | 117 ------------------------------------------ 1 file changed, 117 deletions(-) delete mode 100644 examples/plot_OTDA_classes.py (limited to 'examples/plot_OTDA_classes.py') diff --git a/examples/plot_OTDA_classes.py b/examples/plot_OTDA_classes.py deleted file mode 100644 index 53e4bae..0000000 --- a/examples/plot_OTDA_classes.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -""" -======================== -OT for domain adaptation -======================== - -""" - -# Author: Remi Flamary -# -# License: MIT License - -import matplotlib.pylab as pl -import ot - - -#%% parameters - -n = 150 # nb samples in source and target datasets - -xs, ys = ot.datasets.get_data_classif('3gauss', n) -xt, yt = ot.datasets.get_data_classif('3gauss2', n) - - -#%% plot samples - -pl.figure(1, figsize=(6.4, 3)) - -pl.subplot(1, 2, 1) -pl.scatter(xs[:, 0], xs[:, 1], c=ys, marker='+', label='Source samples') -pl.legend(loc=0) -pl.title('Source distributions') - -pl.subplot(1, 2, 2) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', label='Target samples') -pl.legend(loc=0) -pl.title('target distributions') - - -#%% OT estimation - -# LP problem -da_emd = ot.da.OTDA() # init class -da_emd.fit(xs, xt) # fit distributions -xst0 = da_emd.interp() # interpolation of source samples - -# sinkhorn regularization -lambd = 1e-1 -da_entrop = ot.da.OTDA_sinkhorn() -da_entrop.fit(xs, xt, reg=lambd) -xsts = da_entrop.interp() - -# non-convex Group lasso regularization -reg = 1e-1 -eta = 1e0 -da_lpl1 = ot.da.OTDA_lpl1() -da_lpl1.fit(xs, ys, xt, reg=reg, eta=eta) -xstg = da_lpl1.interp() - -# True Group lasso regularization -reg = 1e-1 -eta = 2e0 -da_l1l2 = ot.da.OTDA_l1l2() -da_l1l2.fit(xs, ys, xt, reg=reg, eta=eta, numItermax=20, verbose=True) -xstgl = da_l1l2.interp() - -#%% plot interpolated source samples - -param_img = {'interpolation': 'nearest', 'cmap': 'spectral'} - -pl.figure(2, figsize=(8, 4.5)) -pl.subplot(2, 4, 1) -pl.imshow(da_emd.G, **param_img) -pl.title('OT matrix') - -pl.subplot(2, 4, 2) -pl.imshow(da_entrop.G, **param_img) -pl.title('OT matrix\nsinkhorn') - -pl.subplot(2, 4, 3) -pl.imshow(da_lpl1.G, **param_img) -pl.title('OT matrix\nnon-convex Group Lasso') - -pl.subplot(2, 4, 4) -pl.imshow(da_l1l2.G, **param_img) -pl.title('OT matrix\nGroup Lasso') - -pl.subplot(2, 4, 5) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(xst0[:, 0], xst0[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Interp samples') -pl.legend(loc=0) - -pl.subplot(2, 4, 6) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(xsts[:, 0], xsts[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Interp samples\nSinkhorn') - -pl.subplot(2, 4, 7) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(xstg[:, 0], xstg[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Interp samples\nnon-convex Group Lasso') - -pl.subplot(2, 4, 8) -pl.scatter(xt[:, 0], xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(xstgl[:, 0], xstgl[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.title('Interp samples\nGroup Lasso') -pl.tight_layout() -pl.show() -- cgit v1.2.3