author	Rémi Flamary <remi.flamary@gmail.com>	2017-09-14 16:52:29 +0200
committer	GitHub <noreply@github.com>	2017-09-14 16:52:29 +0200
commit	a9427a1dd3fd9e99b0fee349eb714c93ac6faf83 (patch)
tree	7d0c5d7336ac6f69630bdc61b69b002524c9eb85 /examples
parent	2324b1f629ebea5b6b58d0b32082032f2914f7e3 (diff)
parent	e70d5420204db78691af2d0fbe04cc3d4416a8f4 (diff)
Merge branch 'master' into autonb
Diffstat (limited to 'examples')
-rw-r--r--	examples/da/plot_otda_semi_supervised.py	147
-rw-r--r--	examples/plot_gromov.py	90
-rwxr-xr-x	examples/plot_gromov_barycenter.py	248
3 files changed, 485 insertions, 0 deletions
diff --git a/examples/da/plot_otda_semi_supervised.py b/examples/da/plot_otda_semi_supervised.py
new file mode 100644
index 0000000..8095c4d
--- /dev/null
+++ b/examples/da/plot_otda_semi_supervised.py
@@ -0,0 +1,147 @@
+# -*- coding: utf-8 -*-
+"""
+============================================
+OTDA unsupervised vs semi-supervised setting
+============================================
+
+This example introduces semi-supervised domain adaptation in a 2D setting.
+It states the problem of semi-supervised domain adaptation and introduces
+some optimal transport approaches to solve it.
+
+Quantities such as optimal couplings, the largest coupling coefficients and
+transported samples are represented in order to give a visual understanding
+of what the transport methods are doing.
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Stanislas Chambon <stan.chambon@gmail.com>
+#
+# License: MIT License
+
+import matplotlib.pylab as pl
+import ot
+
+
+##############################################################################
+# generate data
+##############################################################################
+
+n_samples_source = 150
+n_samples_target = 150
+
+Xs, ys = ot.datasets.get_data_classif('3gauss', n_samples_source)
+Xt, yt = ot.datasets.get_data_classif('3gauss2', n_samples_target)
+
+
+##############################################################################
+# Transport source samples onto target samples
+##############################################################################
+
+# unsupervised domain adaptation
+ot_sinkhorn_un = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_un.fit(Xs=Xs, Xt=Xt)
+transp_Xs_sinkhorn_un = ot_sinkhorn_un.transform(Xs=Xs)
+
+# semi-supervised domain adaptation
+ot_sinkhorn_semi = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_semi.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt)
+transp_Xs_sinkhorn_semi = ot_sinkhorn_semi.transform(Xs=Xs)
+
+# Semi-supervised DA uses the available labeled target samples to modify the
+# cost matrix involved in the OT problem. The cost of transporting a source
+# sample of class A onto a target sample of class B != A is set to infinity
+# (in practice, to a very large value).
+
+# Note that in the present case we consider that all the target samples are
+# labeled. In real-world applications, some target samples might not be
+# labeled; in this case the elements of yt corresponding to these samples
+# should be filled with -1.
+
+# Warning: we recall that -1 therefore cannot be used as a regular class label.
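+
+# A minimal sketch of this partially labeled setting (hypothetical names, not
+# used in the rest of the example): mask half of the target labels with -1 so
+# that only the remaining labels constrain the cost matrix.
+yt_partial = yt.copy()
+yt_partial[n_samples_target // 2:] = -1  # treated as unlabeled target samples
+ot_sinkhorn_partial = ot.da.SinkhornTransport(reg_e=1e-1)
+ot_sinkhorn_partial.fit(Xs=Xs, Xt=Xt, ys=ys, yt=yt_partial)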
+
+
+##############################################################################
+# Fig 1 : plots source and target samples + matrix of pairwise distance
+##############################################################################
+
+pl.figure(1, figsize=(10, 10))
+pl.subplot(2, 2, 1)
+pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Source samples')
+
+pl.subplot(2, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples')
+pl.xticks([])
+pl.yticks([])
+pl.legend(loc=0)
+pl.title('Target samples')
+
+pl.subplot(2, 2, 3)
+pl.imshow(ot_sinkhorn_un.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - unsupervised DA')
+
+pl.subplot(2, 2, 4)
+pl.imshow(ot_sinkhorn_semi.cost_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Cost matrix - semi-supervised DA')
+
+pl.tight_layout()
+
+# The optimal coupling in the semi-supervised DA case will exhibit a shape
+# similar to that of the cost matrix (block-diagonal structure).
+
+
+##############################################################################
+# Fig 2 : plots optimal couplings for the different methods
+##############################################################################
+
+pl.figure(2, figsize=(8, 4))
+
+pl.subplot(1, 2, 1)
+pl.imshow(ot_sinkhorn_un.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nUnsupervised DA')
+
+pl.subplot(1, 2, 2)
+pl.imshow(ot_sinkhorn_semi.coupling_, interpolation='nearest')
+pl.xticks([])
+pl.yticks([])
+pl.title('Optimal coupling\nSemi-supervised DA')
+
+pl.tight_layout()
+
+
+##############################################################################
+# Fig 3 : plot transported samples
+##############################################################################
+
+# display transported samples
+pl.figure(3, figsize=(8, 4))
+pl.subplot(1, 2, 1)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_un[:, 0], transp_Xs_sinkhorn_un[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nUnsupervised DA')
+pl.legend(loc=0)
+pl.xticks([])
+pl.yticks([])
+
+pl.subplot(1, 2, 2)
+pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o',
+ label='Target samples', alpha=0.5)
+pl.scatter(transp_Xs_sinkhorn_semi[:, 0], transp_Xs_sinkhorn_semi[:, 1], c=ys,
+ marker='+', label='Transp samples', s=30)
+pl.title('Transported samples\nSemi-supervised DA')
+pl.xticks([])
+pl.yticks([])
+
+pl.tight_layout()
+pl.show()
diff --git a/examples/plot_gromov.py b/examples/plot_gromov.py
new file mode 100644
index 0000000..dce66c4
--- /dev/null
+++ b/examples/plot_gromov.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+"""
+==========================
+Gromov-Wasserstein example
+==========================
+This example is designed to show how to use the Gromov-Wasserstein distance
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+import scipy as sp
+import scipy.linalg  # noqa: F401 (for sp.linalg.sqrtm)
+import scipy.spatial.distance  # noqa: F401 (for sp.spatial.distance.cdist)
+import numpy as np
+import matplotlib.pylab as pl
+from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers 3d projection)
+
+import ot
+
+
+"""
+Sample two Gaussian distributions (2D and 3D)
+=============================================
+The Gromov-Wasserstein distance allows computing distances between samples that
+do not live in the same metric space. For demonstration purposes, we sample
+two Gaussian distributions in 2- and 3-dimensional spaces.
+"""
+
+n_samples = 30 # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4, 4])
+cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+
+xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s)
+P = sp.linalg.sqrtm(cov_t)
+xt = np.random.randn(n_samples, 3).dot(P) + mu_t
+
+
+"""
+Plotting the distributions
+==========================
+"""
+fig = pl.figure()
+ax1 = fig.add_subplot(121)
+ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+ax2 = fig.add_subplot(122, projection='3d')
+ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
+pl.show()
+
+
+"""
+Compute distance kernels, normalize them and then display
+=========================================================
+"""
+
+C1 = sp.spatial.distance.cdist(xs, xs)
+C2 = sp.spatial.distance.cdist(xt, xt)
+
+C1 /= C1.max()
+C2 /= C2.max()
+
+pl.figure()
+pl.subplot(121)
+pl.imshow(C1)
+pl.subplot(122)
+pl.imshow(C2)
+pl.show()
+
+"""
+Compute Gromov-Wasserstein plans and distance
+=============================================
+"""
+
+p = ot.unif(n_samples)
+q = ot.unif(n_samples)
+
+gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
+
+print('Gromov-Wasserstein distance between the distributions: ' + str(gw_dist))
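+
+# A quick sanity check (sketch): the returned plan gw is a coupling, so its
+# marginals should match the input weights p and q up to numerical tolerance.
+print('Max deviation of the GW plan marginals: {:.2e} / {:.2e}'.format(
+    np.abs(gw.sum(axis=1) - p).max(), np.abs(gw.sum(axis=0) - q).max()))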
+
+pl.figure()
+pl.imshow(gw, cmap='jet')
+pl.colorbar()
+pl.show()
diff --git a/examples/plot_gromov_barycenter.py b/examples/plot_gromov_barycenter.py
new file mode 100755
index 0000000..52f4966
--- /dev/null
+++ b/examples/plot_gromov_barycenter.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+=====================================
+Gromov-Wasserstein Barycenter example
+=====================================
+This example is designed to show how to use the Gromov-Wasserstein barycenter
+computation in POT.
+"""
+
+# Author: Erwan Vautier <erwan.vautier@gmail.com>
+# Nicolas Courty <ncourty@irisa.fr>
+#
+# License: MIT License
+
+
+import numpy as np
+import scipy as sp
+
+import scipy.ndimage as spi
+import matplotlib.pylab as pl
+from sklearn import manifold
+from sklearn.decomposition import PCA
+
+import ot
+
+"""
+
+Smacof MDS
+==========
+This function allows to find an embedding of points given a dissimilarity matrix
+that will be given by the output of the algorithm
+"""
+
+
+def smacof_mds(C, dim, max_iter=3000, eps=1e-9):
+ """
+ Returns an interpolated point cloud following the dissimilarity matrix C
+    using SMACOF multidimensional scaling (MDS) in a target space of the
+    specified dimension
+
+ Parameters
+ ----------
+ C : ndarray, shape (ns, ns)
+ dissimilarity matrix
+ dim : int
+ dimension of the targeted space
+ max_iter : int
+ Maximum number of iterations of the SMACOF algorithm for a single run
+ eps : float
+        relative tolerance w.r.t stress to declare convergence
+
+ Returns
+ -------
+    npos : ndarray, shape (ns, dim)
+ Embedded coordinates of the interpolated point cloud (defined with
+ one isometry)
+ """
+
+ rng = np.random.RandomState(seed=3)
+
+    mds = manifold.MDS(
+        dim,
+        max_iter=max_iter,
+        eps=eps,
+        dissimilarity='precomputed',
+        n_init=1)
+    pos = mds.fit(C).embedding_
+
+    nmds = manifold.MDS(
+        dim,
+        max_iter=max_iter,
+        eps=eps,
+        dissimilarity='precomputed',
+        random_state=rng,
+        n_init=1)
+ npos = nmds.fit_transform(C, init=pos)
+
+ return npos
+
+
+"""
+Data preparation
+================
+The four distributions are constructed from 4 simple images
+"""
+
+
+def im2mat(I):
+ """Converts and image to matrix (one pixel per line)"""
+ return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
+
+
+square = spi.imread('../data/square.png').astype(np.float64)[:, :, 2] / 256
+cross = spi.imread('../data/cross.png').astype(np.float64)[:, :, 2] / 256
+triangle = spi.imread('../data/triangle.png').astype(np.float64)[:, :, 2] / 256
+star = spi.imread('../data/star.png').astype(np.float64)[:, :, 2] / 256
+
+shapes = [square, cross, triangle, star]
+
+S = 4
+xs = [[] for i in range(S)]
+
+
+for nb in range(4):
+ for i in range(8):
+ for j in range(8):
+ if shapes[nb][i, j] < 0.95:
+ xs[nb].append([j, 8 - i])
+
+xs = np.array([np.array(xs[0]), np.array(xs[1]),
+ np.array(xs[2]), np.array(xs[3])])
+
+
+"""
+Barycenter computation
+======================
+The four distributions are constructed from 4 simple images
+"""
+ns = [len(xs[s]) for s in range(S)]
+n_samples = 30
+
+"""Compute all distances matrices for the four shapes"""
+Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
+Cs = [cs / cs.max() for cs in Cs]
+
+ps = [ot.unif(ns[s]) for s in range(S)]
+p = ot.unif(n_samples)
+
+
+lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
+
+Ct01 = [0 for i in range(2)]
+for i in range(2):
+    Ct01[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[0], Cs[1]], [ps[0], ps[1]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, stopThr=1e-3)
+
+Ct02 = [0 for i in range(2)]
+for i in range(2):
+    Ct02[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[0], Cs[2]], [ps[0], ps[2]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, stopThr=1e-3)
+
+Ct13 = [0 for i in range(2)]
+for i in range(2):
+    Ct13[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[1], Cs[3]], [ps[1], ps[3]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, stopThr=1e-3)
+
+Ct23 = [0 for i in range(2)]
+for i in range(2):
+    Ct23[i] = ot.gromov.gromov_barycenters(
+        n_samples, [Cs[2], Cs[3]], [ps[2], ps[3]], p, lambdast[i],
+        'square_loss', 5e-4, max_iter=100, stopThr=1e-3)
+
+"""
+Visualization
+=============
+"""
+
+"""The PCA helps in getting consistency between the rotations"""
+
+clf = PCA(n_components=2)
+npos = [smacof_mds(Cs[s], 2) for s in range(S)]
+
+npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
+npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
+
+npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
+npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
+
+npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
+npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
+
+npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
+npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
+
+
+fig = pl.figure(figsize=(10, 10))
+
+ax1 = pl.subplot2grid((4, 4), (0, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
+
+ax2 = pl.subplot2grid((4, 4), (0, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
+
+ax3 = pl.subplot2grid((4, 4), (0, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
+
+ax4 = pl.subplot2grid((4, 4), (0, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
+
+ax5 = pl.subplot2grid((4, 4), (1, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
+
+ax6 = pl.subplot2grid((4, 4), (1, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
+
+ax7 = pl.subplot2grid((4, 4), (2, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
+
+ax8 = pl.subplot2grid((4, 4), (2, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
+
+ax9 = pl.subplot2grid((4, 4), (3, 0))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
+
+ax10 = pl.subplot2grid((4, 4), (3, 1))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
+
+ax11 = pl.subplot2grid((4, 4), (3, 2))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
+
+ax12 = pl.subplot2grid((4, 4), (3, 3))
+pl.xlim((-1, 1))
+pl.ylim((-1, 1))
+ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')