diff options
-rw-r--r-- | examples/plot_otda_jcpot.py | 185 | ||||
-rw-r--r-- | ot/da.py | 263 | ||||
-rw-r--r-- | test/test_da.py | 63 |
3 files changed, 404 insertions, 107 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py new file mode 100644 index 0000000..5e5fff8 --- /dev/null +++ b/examples/plot_otda_jcpot.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +""" +================================ +OT for multi-source target shift +================================ + +This example introduces a target shift problem with two 2D source domains and one target domain. + +""" + +# Authors: Remi Flamary <remi.flamary@unice.fr> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> +# +# License: MIT License + +import pylab as pl +import numpy as np +import ot + +############################################################################## +# Generate data +# ------------- +n = 50 +sigma = 0.3 +np.random.seed(1985) + + +def get_data(n, p, dec): + y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n)))) + x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + sigma * np.random.randn(len(y), 2) + + x[:, 0] += dec[0] + x[:, 1] += dec[1] + + return x, y + + +p1 = .2 +dec1 = [0, 2] + +p2 = .9 +dec2 = [0, -2] + +pt = .4 +dect = [4, 0] + +xs1, ys1 = get_data(n, p1, dec1) +xs2, ys2 = get_data(n + 1, p2, dec2) +xt, yt = get_data(n, pt, dect) +all_Xr = [xs1, xs2] +all_Yr = [ys1, ys2] +# %% +da = 1.5 + + +def plot_ax(dec, name): + pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) + pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) + pl.text(dec[0] - .5, dec[1] + 2, name) + + +############################################################################## +# Fig 1 : plots source and target samples +# --------------------------------------- + +pl.figure(1) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 (0.8,0.2)') +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 (0.1,0.9)') +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', 
cmap='Set1', vmax=9, label='Target (0.6,0.4)') +pl.title('Data') + +pl.legend() +pl.axis('equal') +pl.axis('off') + + +############################################################################## +# Instantiate Sinkhorn transport algorithm and fit them for all source domains +# ---------------------------------------------------------------------------- +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-2, metric='euclidean') + +M1 = ot.dist(xs1, xt, 'euclidean') +M2 = ot.dist(xs2, xt, 'euclidean') + + +def print_G(G, xs, ys, xt): + for i in range(G.shape[0]): + for j in range(G.shape[1]): + if G[i, j] > 5e-4: + if ys[i]: + c = 'b' + else: + c = 'r' + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2) + + +############################################################################## +# Fig 2 : plot optimal couplings and transported samples +# ------------------------------------------------------ +pl.figure(2) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt) +print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('Independent OT') + +pl.legend() +pl.axis('equal') +pl.axis('off') + + +############################################################################## +# Instantiate JCPOT adaptation algorithm and fit it +# ---------------------------------------------------------------------------- +otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, tol=1e-9, verbose=True, log=True) +otda.fit(all_Xr, all_Yr, xt) + +ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2']) 
+ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2']) + +pl.figure(3) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1])) + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Run oracle transport algorithm with known proportions +# ---------------------------------------------------------------------------- + +otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True) +otda.fit(all_Xr, all_Yr, xt) + +h_res = np.array([1 - pt, pt]) + +ws1 = h_res.dot(otda.log_['all_domains'][0]['D2']) +ws2 = h_res.dot(otda.log_['all_domains'][1]['D2']) + +pl.figure(4) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1])) + +pl.legend() 
+pl.axis('equal') +pl.axis('off') +pl.show() @@ -748,79 +748,58 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs): - """Joint OT and proportion estimation as proposed in [27] + r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] The function solves the following optimization problem: .. math:: - \mathbf{h} = \argmin_{\mathbf{h} \in \Delta_C}\quad \sum_{k=1}^K \lambda_k - W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T \mathbf{\delta}_{\mathbf{X}^{(k)}}, \mu\right) + \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) + s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} - s.t. \gamma^T_k \mathbf{1}_n = \mathbf{1}_n/n - - \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} - - \gamma\geq 0 where : - - M is the (ns,nt) squared euclidean cost matrix between samples in - Xs and Xt (scaled by ns) - - :math:`L` is a ns x d linear operator on a kernel matrix that - approximates the barycentric mapping - - a and b are uniform source and target weights - - The problem consist in solving jointly an optimal transport matrix - :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping - :math:`n_s\gamma X_t`. + - :math:`\lambda_k` is the weight of k-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 
5, 27], its expected shape is `(C, n_k)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(C, n_k)` - One can also estimate a mapping with constant bias (see supplementary - material of [8]) using the bias optional argument. - - The algorithm used for solving the problem is the block coordinate - descent that alternates between updates of G (using conditional gradient) - and the update of L using a classical kernel least square solver. + The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + The algorithm used for solving the problem is the Iterative Bregman projections algorithm + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and the uniform target distribution. Parameters ---------- - xs : np.ndarray (ns,d) - samples in the source domain - xt : np.ndarray (nt,d) + Xs : list of K np.ndarray(nsk,d) + features of all source domains' samples + Ys : list of K np.ndarray(nsk,) + labels of all source domains' samples + Xt : np.ndarray (nt,d) samples in the target domain - mu : float,optional - Weight for the linear OT loss (>0) - eta : float, optional - Regularization term for the linear mapping L (>0) - kerneltype : str,optional - kernel used by calling function ot.utils.kernel (gaussian by default) - sigma : float, optional - Gaussian kernel bandwidth. 
- bias : bool,optional - Estimate linear mapping with constant bias - verbose : bool, optional - Print information along iterations - verbose2 : bool, optional - Print information along iterations + reg : float + Regularization term > 0 + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem numItermax : int, optional - Max number of BCD iterations - numInnerItermax : int, optional - Max number of iterations (inner CG solver) - stopInnerThr : float, optional - Stop threshold on error (inner CG solver) (>0) + Max number of iterations stopThr : float, optional - Stop threshold on relative loss decrease (>0) + Stop threshold on relative change in the barycenter (>0) log : bool, optional record log if True - + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm Returns ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - L : (ns x d) ndarray - Nonlinear mapping matrix (ns+1 x d if bias) + gamma : List of K (nsk x nt) ndarrays + Optimal transportation matrices for the given parameters for each pair of source and target domains + h : (C,) ndarray + proportion estimation in the target domain log : dict log dictionary return only if log==True in parameters @@ -828,62 +807,59 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, References ---------- - .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, - "Mapping estimation for discrete optimal transport", - Neural Information Processing Systems (NIPS), 2016. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
- """ + ''' nbclasses = len(np.unique(Ys[0])) nbdomains = len(Xs) - # we then build, for each source domain, specific information + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 all_domains = [] + + # log dictionary + if log: + log = {'niter': 0, 'err': [], 'all_domains': []} + for d in range(nbdomains): dom = {} - # get number of elements for this domain - nb_elem = Xs[d].shape[0] - dom['nbelem'] = nb_elem - classes = np.unique(Ys[d]) + nsk = Xs[d].shape[0] # get number of elements for this domain + dom['nbelem'] = nsk + classes = np.unique(Ys[d]) # get number of classes for this domain + # format classes to start from 0 for convenience if np.min(classes) != 0: Ys[d] = Ys[d] - np.min(classes) classes = np.unique(Ys[d]) - # build the corresponding D matrix - D1 = np.zeros((nbclasses, nb_elem)) - D2 = np.zeros((nbclasses, nb_elem)) - classes_d = np.zeros(nbclasses) - - classes_d[np.unique(Ys[d]).astype(int)] = 1 - dom['classes'] = classes_d + # build the corresponding D_1 and D_2 matrices + D1 = np.zeros((nbclasses, nsk)) + D2 = np.zeros((nbclasses, nsk)) for c in classes: nbelemperclass = np.sum(Ys[d] == c) if nbelemperclass != 0: D1[int(c), Ys[d] == c] = 1. - D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) # *nbclasses_d) + D2[int(c), Ys[d] == c] = 1. 
/ (nbelemperclass) dom['D1'] = D1 dom['D2'] = D2 - # build the distance matrix + # build the cost matrix and the Gibbs kernel M = dist(Xs[d], Xt, metric=metric) M = M / np.median(M) - dom['K'] = np.exp(-M/reg) + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + dom['K'] = K all_domains.append(dom) - distribT = unif(np.shape(Xt)[0]) - - if log: - log = {'niter': 0, 'err': []} + # uniform target distribution + a = unif(np.shape(Xt)[0]) - cpt = 0 + cpt = 0 # iterations count err = 1 old_bary = np.ones((nbclasses)) @@ -891,13 +867,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, bary = np.zeros((nbclasses)) + # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): - all_domains[d]['K'] = projC(all_domains[d]['K'], distribT) + all_domains[d]['K'] = projC(all_domains[d]['K'], a) other = np.sum(all_domains[d]['K'], axis=1) bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains bary = np.exp(bary) + # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): new = np.dot(all_domains[d]['D2'].T, bary) all_domains[d]['K'] = projR(all_domains[d]['K'], new) @@ -915,12 +893,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5d}|{:8e}|'.format(cpt, err)) bary = bary / np.sum(bary) + couplings = [all_domains[d]['K'] for d in range(nbdomains)] if log: log['niter'] = cpt - return bary, log + log['all_domains'] = all_domains + return couplings, bary, log else: - return bary + return couplings, bary def distribution_estimation_uniform(X): @@ -2093,9 +2073,10 @@ class UnbalancedSinkhornTransport(BaseTransport): return self + class JCPOTTransport(BaseTransport): - """Domain Adapatation OT method for target shift based on sinkhorn algorithm. 
+ """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm. Parameters ---------- @@ -2104,8 +2085,6 @@ class JCPOTTransport(BaseTransport): max_iter : int, float, optional (default=10) The minimum number of iteration before stopping the optimization algorithm if no it has not converged - max_inner_iter : int, float, optional (default=200) - The number of iteration in the inner loop tol : float, optional (default=10e-9) Stop threshold on error (inner sinkhorn solver) (>0) verbose : bool, optional (default=False) @@ -2126,21 +2105,20 @@ class JCPOTTransport(BaseTransport): Attributes ---------- - coupling_ : array-like, shape (n_source_samples, n_target_samples) - The optimal coupling + coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples) + A set of optimal couplings between each source domain and the target domain + proportions_ : array-like, shape (n_classes,) + Estimated class proportions in the target domain log_ : dictionary The dictionary of log, empty dic if parameter log is not True References ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , - vol.PP, no.99, pp.1-1 - .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). - Generalized conditional gradient: analysis of convergence - and applications. arXiv preprint arXiv:1510.06567. + .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), + vol. 89, p.849-858, 2019. 
""" @@ -2156,20 +2134,18 @@ class JCPOTTransport(BaseTransport): self.verbose = verbose self.log = log self.metric = metric - self.norm = norm - self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + """Building coupling matrices from a list of source and target sets of samples (Xs, ys) and (Xt, yt) Parameters ---------- - Xs : array-like, shape (n_source_samples, n_features) - The training input samples. - ys : array-like, shape (n_source_samples,) - The class labels + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) @@ -2188,15 +2164,90 @@ class JCPOTTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): - returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg = self.reg_e, - metric=self.metric, numItermax=self.max_iter, stopThr=self.tol, + self.xs_ = Xs + self.xt_ = Xt + + returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, + metric=self.metric, numItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) # deal with the value of log if self.log: - self.coupling_, self.log_ = returned_ + self.coupling_, self.proportions_, self.log_ = returned_ else: - self.coupling_ = returned_ + self.coupling_, self.proportions_ = returned_ self.log_ = dict() return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. 
+ ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + """ + + transp_Xs = [] + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + + # perform standard barycentric mapping for each source domain + + for coupling in self.coupling_: + transp = coupling / np.sum(coupling, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs.append(np.dot(transp, self.xt_)) + else: + + # perform out of sample mapping + indices = np.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xs = [] + + for bi in batch_ind: + transp_Xs_ = [] + + # get the nearest neighbor in the sources domains + xs = np.concatenate(self.xs_, axis=0) + idx = np.argmin(dist(Xs[bi], xs), axis=1) + + # transport the source samples + for coupling in self.coupling_: + transp = coupling / np.sum( + coupling, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_.append(np.dot(transp, self.xt_)) + + transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + + # define the transported points + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] + transp_Xs.append(transp_Xs_) + + transp_Xs = np.concatenate(transp_Xs, axis=0) + + return transp_Xs diff --git a/test/test_da.py b/test/test_da.py index 2a5e50e..a8c258a 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -5,7 +5,7 @@ # License: MIT License import numpy as np -from numpy.testing.utils import assert_allclose, assert_equal +from 
numpy.testing import assert_allclose, assert_equal import ot from ot.datasets import make_data_classif @@ -549,3 +549,64 @@ def test_linear_mapping_class(): Cst = np.cov(Xst.T) np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_jcpot_transport_class(): + """test_jcpot_transport + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + Xs1, ys1 = make_data_classif('3gauss', ns1) + Xs2, ys2 = make_data_classif('3gauss', ns2) + + Xt, yt = make_data_classif('3gauss2', nt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + print(otda.proportions_) + + assert hasattr(otda, "coupling_") + assert hasattr(otda, "proportions_") + assert hasattr(otda, "log_") + + # test dimensions of coupling + for i, xs in enumerate(Xs): + assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0]))) + + # test all margin constraints + mu_t = unif(nt) + + for i in range(len(Xs)): + # test margin constraints w.r.t. uniform target weights for each coupling matrix + assert_allclose( + np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3) + + # test margin constraints w.r.t. modified source weights for each source domain + + D1 = np.zeros((len(np.unique(ys[i])), len(ys[i]))) + for c in np.unique(ys[i]): + nbelemperclass = np.sum(ys[i] == c) + if nbelemperclass != 0: + D1[int(c), ys[i] == c] = 1. + + assert_allclose( + np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) + + # test transform + transp_Xs = otda.transform(Xs=Xs) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] + #assert_equal(transp_Xs.shape, Xs.shape) + + Xs_new, _ = make_data_classif('3gauss', ns1 + 1) + transp_Xs_new = otda.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) |