diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2020-04-20 22:04:03 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2020-04-20 22:04:03 +0200 |
commit | 21949bbc3469234f88972bdfe973f68eb9e62794 (patch) | |
tree | 6bc93db587bd80d0ccb9e33596c4526aeaefec4c | |
parent | d54184c233cd211a693e4cdf4b25dd68b07ed00b (diff) | |
parent | 43b2190db71b1ccbeec8fddaae23ca6af220e1b5 (diff) |
Merge branch 'master' into doc_travis
-rw-r--r-- | README.md | 45 | ||||
-rw-r--r-- | examples/plot_otda_laplacian.py | 127 | ||||
-rw-r--r-- | ot/da.py | 253 | ||||
-rw-r--r-- | ot/utils.py | 6 | ||||
-rw-r--r-- | test/test_da.py | 64 |
5 files changed, 473 insertions, 22 deletions
@@ -2,11 +2,11 @@ [![PyPI version](https://badge.fury.io/py/POT.svg)](https://badge.fury.io/py/POT) [![Anaconda Cloud](https://anaconda.org/conda-forge/pot/badges/version.svg)](https://anaconda.org/conda-forge/pot) -[![Build Status](https://travis-ci.org/rflamary/POT.svg?branch=master)](https://travis-ci.org/rflamary/POT) +[![Build Status](https://travis-ci.org/rflamary/POT.svg?branch=master)](https://travis-ci.org/PythonOT/POT) [![Documentation Status](https://readthedocs.org/projects/pot/badge/?version=latest)](http://pot.readthedocs.io/en/latest/?badge=latest) [![Downloads](https://pepy.tech/badge/pot)](https://pepy.tech/project/pot) [![Anaconda downloads](https://anaconda.org/conda-forge/pot/badges/downloads.svg)](https://anaconda.org/conda-forge/pot) -[![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/rflamary/POT/blob/master/LICENSE) +[![License](https://anaconda.org/conda-forge/pot/badges/license.svg)](https://github.com/PythonOT/POT/blob/master/LICENSE) @@ -20,7 +20,7 @@ It provides the following solvers: * Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17]. * Non regularized Wasserstein barycenters [16] with LP solver (only small scale). * Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4]. -* Optimal transport for domain adaptation with group lasso regularization [5] +* Optimal transport for domain adaptation with group lasso regularization and Laplacian regularization [5][30] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. * Linear OT [14] and Joint OT matrix and mapping estimation [8]. * Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt). @@ -140,26 +140,26 @@ ba=ot.barycenter(A,M,reg) # reg is regularization parameter The examples folder contain several examples and use case for the library. 
The full documentation is available on [Readthedocs](http://pot.readthedocs.io/). -Here is a list of the Python notebooks available [here](https://github.com/rflamary/POT/blob/master/notebooks/) if you want a quick look: +Here is a list of the Python notebooks available [here](https://github.com/PythonOT/POT/blob/master/notebooks/) if you want a quick look: -* [1D optimal transport](https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_1D.ipynb) -* [OT Ground Loss](https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_L1_vs_L2.ipynb) -* [Multiple EMD computation](https://github.com/rflamary/POT/blob/master/notebooks/plot_compute_emd.ipynb) -* [2D optimal transport on empirical distributions](https://github.com/rflamary/POT/blob/master/notebooks/plot_OT_2D_samples.ipynb) -* [1D Wasserstein barycenter](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_1D.ipynb) -* [OT with user provided regularization](https://github.com/rflamary/POT/blob/master/notebooks/plot_optim_OTreg.ipynb) -* [Domain adaptation with optimal transport](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_d2.ipynb) -* [Color transfer in images](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_color_images.ipynb) -* [OT mapping estimation for domain adaptation](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping.ipynb) -* [OT mapping estimation for color transfer in images](https://github.com/rflamary/POT/blob/master/notebooks/plot_otda_mapping_colors_images.ipynb) -* [Wasserstein Discriminant Analysis](https://github.com/rflamary/POT/blob/master/notebooks/plot_WDA.ipynb) -* [Gromov Wasserstein](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov.ipynb) -* [Gromov Wasserstein Barycenter](https://github.com/rflamary/POT/blob/master/notebooks/plot_gromov_barycenter.ipynb) -* [Fused Gromov Wasserstein](https://github.com/rflamary/POT/blob/master/notebooks/plot_fgw.ipynb) -* [Fused Gromov Wasserstein 
Barycenter](https://github.com/rflamary/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb) +* [1D optimal transport](https://github.com/PythonOT/POT/blob/master/notebooks/plot_OT_1D.ipynb) +* [OT Ground Loss](https://github.com/PythonOT/POT/blob/master/notebooks/plot_OT_L1_vs_L2.ipynb) +* [Multiple EMD computation](https://github.com/PythonOT/POT/blob/master/notebooks/plot_compute_emd.ipynb) +* [2D optimal transport on empirical distributions](https://github.com/PythonOT/POT/blob/master/notebooks/plot_OT_2D_samples.ipynb) +* [1D Wasserstein barycenter](https://github.com/PythonOT/POT/blob/master/notebooks/plot_barycenter_1D.ipynb) +* [OT with user provided regularization](https://github.com/PythonOT/POT/blob/master/notebooks/plot_optim_OTreg.ipynb) +* [Domain adaptation with optimal transport](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_d2.ipynb) +* [Color transfer in images](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_color_images.ipynb) +* [OT mapping estimation for domain adaptation](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_mapping.ipynb) +* [OT mapping estimation for color transfer in images](https://github.com/PythonOT/POT/blob/master/notebooks/plot_otda_mapping_colors_images.ipynb) +* [Wasserstein Discriminant Analysis](https://github.com/PythonOT/POT/blob/master/notebooks/plot_WDA.ipynb) +* [Gromov Wasserstein](https://github.com/PythonOT/POT/blob/master/notebooks/plot_gromov.ipynb) +* [Gromov Wasserstein Barycenter](https://github.com/PythonOT/POT/blob/master/notebooks/plot_gromov_barycenter.ipynb) +* [Fused Gromov Wasserstein](https://github.com/PythonOT/POT/blob/master/notebooks/plot_fgw.ipynb) +* [Fused Gromov Wasserstein Barycenter](https://github.com/PythonOT/POT/blob/master/notebooks/plot_barycenter_fgw.ipynb) -You can also see the notebooks with [Jupyter nbviewer](https://nbviewer.jupyter.org/github/rflamary/POT/tree/master/notebooks/). 
+You can also see the notebooks with [Jupyter nbviewer](https://nbviewer.jupyter.org/github/PythonOT/POT/tree/master/notebooks/). ## Acknowledgements @@ -184,6 +184,7 @@ The contributors to this library are * [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT) * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) +* [Ievgen Redko](https://ievred.github.io/) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -264,4 +265,6 @@ You can also post bug reports and feature requests in Github issues. Make sure t [28] Caffarelli, L. A., McCann, R. J. (2020). [Free boundaries in optimal transport and Monge-Ampere obstacle problems](http://www.math.toronto.edu/~mccann/papers/annals2010.pdf), Annals of mathematics, 673-730. -[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.
\ No newline at end of file +[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. + +[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. diff --git a/examples/plot_otda_laplacian.py b/examples/plot_otda_laplacian.py new file mode 100644 index 0000000..67c8f67 --- /dev/null +++ b/examples/plot_otda_laplacian.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +====================================================== +OT with Laplacian regularization for domain adaptation +====================================================== + +This example introduces a domain adaptation in a 2D setting and OTDA +approach with Laplacian regularization. 
+ +""" + +# Authors: Ievgen Redko <ievgen.redko@univ-st-etienne.fr> + +# License: MIT License + +import matplotlib.pylab as pl +import ot + +############################################################################## +# Generate data +# ------------- + +n_source_samples = 150 +n_target_samples = 150 + +Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples) +Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples) + + +############################################################################## +# Instantiate the different transport algorithms and fit them +# ----------------------------------------------------------- + +# EMD Transport +ot_emd = ot.da.EMDTransport() +ot_emd.fit(Xs=Xs, Xt=Xt) + +# Sinkhorn Transport +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01) +ot_sinkhorn.fit(Xs=Xs, Xt=Xt) + +# EMD Transport with Laplacian regularization +ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1) +ot_emd_laplace.fit(Xs=Xs, Xt=Xt) + +# transport source samples onto target samples +transp_Xs_emd = ot_emd.transform(Xs=Xs) +transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) +transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs) + +############################################################################## +# Fig 1 : plots source and target samples +# --------------------------------------- + +pl.figure(1, figsize=(10, 5)) +pl.subplot(1, 2, 1) +pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Source samples') + +pl.subplot(1, 2, 2) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') +pl.xticks([]) +pl.yticks([]) +pl.legend(loc=0) +pl.title('Target samples') +pl.tight_layout() + + +############################################################################## +# Fig 2 : plot optimal couplings and transported samples +# ------------------------------------------------------ + +param_img = {'interpolation': 'nearest'} + 
+pl.figure(2, figsize=(15, 8)) +pl.subplot(2, 3, 1) +pl.imshow(ot_emd.coupling_, **param_img) +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nEMDTransport') + +pl.figure(2, figsize=(15, 8)) +pl.subplot(2, 3, 2) +pl.imshow(ot_sinkhorn.coupling_, **param_img) +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nSinkhornTransport') + +pl.subplot(2, 3, 3) +pl.imshow(ot_emd_laplace.coupling_, **param_img) +pl.xticks([]) +pl.yticks([]) +pl.title('Optimal coupling\nEMDLaplaceTransport') + +pl.subplot(2, 3, 4) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.3) +pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.xticks([]) +pl.yticks([]) +pl.title('Transported samples\nEmdTransport') +pl.legend(loc="lower left") + +pl.subplot(2, 3, 5) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.3) +pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.xticks([]) +pl.yticks([]) +pl.title('Transported samples\nSinkhornTransport') + +pl.subplot(2, 3, 6) +pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', + label='Target samples', alpha=0.3) +pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys, + marker='+', label='Transp samples', s=30) +pl.xticks([]) +pl.yticks([]) +pl.title('Transported samples\nEMDLaplaceTransport') +pl.tight_layout() + +pl.show() @@ -16,7 +16,7 @@ import scipy.linalg as linalg from .bregman import sinkhorn, jcpot_barycenter from .lp import emd -from .utils import unif, dist, kernel, cost_normalization, label_normalization +from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots from .utils import check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg @@ -748,6 +748,139 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b +def emd_laplace(a, b, xs, xt, M, 
sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5, + numItermax=100, stopThr=1e-9, numInnerItermax=100000, + stopInnerThr=1e-9, log=False, verbose=False): + r"""Solve the optimal transport problem (OT) with Laplacian regularization + + .. math:: + \gamma = arg\min_\gamma <\gamma,M>_F + \eta\Omega_\alpha(\gamma) + + s.t.\ \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + + where: + + - a and b are source and target weights (sum to 1) + - xs and xt are source and target samples + - M is the (ns,nt) metric cost matrix + - :math:`\Omega_\alpha` is the Laplacian regularization term + :math:`\Omega_\alpha = \alpha/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+(1-\alpha)/n_t^2\sum_{i,j}S^t_{i,j}\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2` + with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping + + The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5]. + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + M : np.ndarray (ns,nt) + loss matrix + sim : string, optional + Type of similarity ('knn' or 'gauss') used to construct the Laplacian. + sim_param : int or float, optional + Parameter (number of the nearest neighbors for sim='knn' + or bandwidth for sim='gauss') used to compute the Laplacian. 
+ reg : string + Type of Laplacian regularization ('pos' or 'disp') + eta : float + Weight of the Laplacian regularization term + alpha : float + Relative importance of the source domain in the regularization (the target domain gets 1 - alpha) + numItermax : int, optional + Max number of iterations (CG solver) + stopThr : float, optional + Stop threshold on error (CG solver) (>0) + numInnerItermax : int, optional + Max number of iterations (inner emd solver) + stopInnerThr : float, optional + Second stop threshold of the CG solver, passed as stopThr2 (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE + Transactions on Pattern Analysis and Machine Intelligence , + vol.PP, no.99, pp.1-1 + .. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy, + "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," + in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + if not isinstance(sim_param, (int, float, type(None))): + raise ValueError( + 'Similarity parameter should be an int or a float. 
Got {type} instead.'.format(type=type(sim_param).__name__)) + + if sim == 'gauss': + if sim_param is None: + sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sS = kernel(xs, xs, method=sim, sigma=sim_param) + sT = kernel(xt, xt, method=sim, sigma=sim_param) + + elif sim == 'knn': + if sim_param is None: + sim_param = 3 + + from sklearn.neighbors import kneighbors_graph + + sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray() + sS = (sS + sS.T) / 2 + sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray() + sT = (sT + sT.T) / 2 + else: + raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) + + lS = laplacian(sS) + lT = laplacian(sT) + + def f(G): + return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ + + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + + ls2 = lS + lS.T + lt2 = lT + lT.T + xt2 = np.dot(xt, xt.T) + + if reg == 'disp': + Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) + Ct = -eta * (1 - alpha) / xt.shape[0] * dots(xs, xt.T, lt2) + M = M + Cs + Ct + + def df(G): + return alpha * np.dot(ls2, np.dot(G, xt2))\ + + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2))) + + return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, + stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) + + def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X @@ -1576,6 +1709,124 @@ class SinkhornLpl1Transport(BaseTransport): return self +class EMDLaplaceTransport(BaseTransport): + + """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization + + Parameters + ---------- + reg_type : string optional (default='pos') + Type of the regularization term: 'pos' and 'disp' for + regularization term defined in [2] and [6], respectively. 
+ reg_lap : float, optional (default=1) + Laplacian regularization parameter + reg_src : float, optional (default=1) + Source relative importance in regularization + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + similarity : string, optional (default="knn") + The similarity to use, either "knn" or "gauss" + similarity_param : int or float, optional (default=None) + Parameter for the similarity: number of nearest neighbors or bandwidth + if similarity="knn" or "gauss", respectively. If None is provided, + it is set to 3 or to 1 / (2 * m ** 2), with m the mean pairwise squared Euclidean distance between source samples, respectively. + max_iter : int, optional (default=100) + Max number of conditional gradient iterations + tol : float, optional (default=1e-9) + Stop threshold on relative loss decrease (>0) + max_inner_iter : int, optional (default=100000) + Max number of iterations (inner EMD solver) + inner_tol : float, optional (default=1e-9) + Stop threshold on error (inner CG solver) (>0) + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + + References + ---------- + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + .. [2] R. Flamary, N. Courty, D. Tuia, A. 
Rakotomamonjy, + "Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching," + in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + """ + + def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean", + norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9, + max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans'): + self.reg = reg_type + self.reg_lap = reg_lap + self.reg_src = reg_src + self.metric = metric + self.norm = norm + self.similarity = similarity + self.sim_param = similarity_param + self.max_iter = max_iter + self.tol = tol + self.max_inner_iter = max_inner_iter + self.inner_tol = inner_tol + self.log = log + self.verbose = verbose + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. 
+ """ + + super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, + xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, + alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, + stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) + + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self + + class SinkhornL1l2Transport(BaseTransport): """Domain Adapatation OT method based on sinkhorn algorithm + diff --git a/ot/utils.py b/ot/utils.py index c154f99..f9911a1 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -49,6 +49,12 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): return K +def laplacian(x): + """Compute Laplacian matrix""" + L = np.diag(np.sum(x, axis=0)) - x + return L + + def unif(n): """ return a uniform histogram of length n (simplex) diff --git a/test/test_da.py b/test/test_da.py index 7d0fdda..3b28119 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -689,3 +689,67 @@ def test_jcpot_barycenter(): numItermax=10000, stopThr=1e-9, verbose=False, log=False) np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) + + +def test_emd_laplace_class(): + """test_emd_laplace_transport + """ + ns = 150 + nt = 200 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + + assert hasattr(otda, "coupling_") + assert hasattr(otda, "log_") + + # test dimensions of coupling + assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) + + # test all margin constraints + mu_s = unif(ns) + mu_t = unif(nt) + + assert_allclose( + np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) + 
assert_allclose( + np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = otda.transform(Xs=Xs) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] + + Xs_new, _ = make_data_classif('3gauss', ns + 1) + transp_Xs_new = otda.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + # test inverse transform + transp_Xt = otda.inverse_transform(Xt=Xt) + assert_equal(transp_Xt.shape, Xt.shape) + + Xt_new, _ = make_data_classif('3gauss2', nt + 1) + transp_Xt_new = otda.inverse_transform(Xt=Xt_new) + + # check that the oos method is working + assert_equal(transp_Xt_new.shape, Xt_new.shape) + + # test fit_transform + transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) + assert_equal(transp_Xs.shape, Xs.shape) + + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) |