From 767171593f2a98a26b9a39bf110a45085e3b982e Mon Sep 17 00:00:00 2001
From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>
Date: Thu, 24 Mar 2022 10:53:47 +0100
Subject: [MRG] Domain adaptation and unbalanced solvers with backend support (#343)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* First draft
* Add matrix inverse and square root to backend
* Eigen decomposition for older versions of pytorch (1.8.1 and older)
* Corrected eigen decomposition for pytorch 1.8.1 and older
* Spectral theorem is a thing
* Optimization
* small optimization
* More functions converted
* pep8
* remove a warning and prepare torch meshgrid for future torch release (which will change default indexing)
* dots and pep8
* Meshgrid corrected for older version and prepared for future versions changes
* New backend functions
* Base transport
* LinearTransport
* All transport classes + pep8
* PR added to release file
* Jcpot barycenter test
* unbalanced with backend
* pep8
* bug solve
* test of domain adaptation with backends
* solve bug for tic toc & macos
* solving scipy deprecation warning
* solving scipy deprecation warning attempt2
* solving scipy deprecation warning attempt3
* A warning is triggered when a float->int conversion is detected
* bug solve
* docs
* release file updated
* Better handling of float->int conversion in EMD
* Corrected test for is_floating_point
* docs
* release file updated
* cupy does not allow implicit cast
* fromnumpy
* added test
* test da tf jax
* test unbalanced with no provided histogram
* using type_as argument in unif function correctly
* pep8
* transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix

Co-authored-by: Rémi Flamary
---
 ot/da.py | 382 +++++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 215 insertions(+), 167 deletions(-)

(limited to 'ot/da.py')

diff --git a/ot/da.py b/ot/da.py
index 841f31a..0b9737e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -12,12 +12,12 @@ Domain adaptation with optimal transport
 # License: MIT License
 
 import numpy as np
-import scipy.linalg as linalg
+from .backend import get_backend
 
 from .bregman import sinkhorn, jcpot_barycenter
 from .lp import emd
 from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator
 from .unbalanced import sinkhorn_unbalanced
 from .optim import cg
 from .optim import gcg
@@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
     Parameters
     ----------
-    a : np.ndarray (ns,)
+    a : array-like (ns,)
         samples weights in the source domain
-    labels_a : np.ndarray (ns,)
+    labels_a : array-like (ns,)
         labels of samples in the source domain
-    b : np.ndarray (nt,)
+    b : array-like (nt,)
         samples weights in the target domain
-    M : np.ndarray (ns,nt)
+    M : array-like (ns,nt)
         loss matrix
     reg : float
         Regularization term for entropic regularization >0
@@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
 
     Returns
     -------
-    gamma : (ns, nt) ndarray
+    gamma : (ns, nt) array-like
         Optimal transportation matrix for the given parameters
     log : dict
         log dictionary return only if log==True in parameters
@@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
     ot.optim.cg : General regularized OT
 
     """
+    a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+    nx = 
get_backend(a, labels_a, b, M) + p = 0.5 epsilon = 1e-3 indices_labels = [] - classes = np.unique(labels_a) + classes = nx.unique(labels_a) for c in classes: - idxc, = np.where(labels_a == c) + idxc, = nx.where(labels_a == c) indices_labels.append(idxc) - W = np.zeros(M.shape) - + W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, stopThr=stopInnerThr) # the transport has been computed. Check if classes are really # separated - W = np.ones(M.shape) + W = nx.ones(M.shape, type_as=M) for (i, c) in enumerate(classes): - majs = np.sum(transp[indices_labels[i]], axis=0) + majs = nx.sum(transp[indices_labels[i]], axis=0) majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs @@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - labels_a : np.ndarray (ns,) + labels_a : array-like (ns,) labels of samples in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples in the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix reg : float Regularization term for entropic regularization >0 @@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, ot.optim.gcg : Generalized conditional gradient for OT problems """ - lstlab = np.unique(labels_a) + a, labels_a, b, M = list_to_array(a, labels_a, b, M) + nx = get_backend(a, labels_a, b, M) + + lstlab = nx.unique(labels_a) def f(G): res = 0 for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - res += np.linalg.norm(temp) + res += nx.norm(temp) return res def df(G): - W = np.zeros(G.shape) + W = nx.zeros(G.shape, type_as=G) for i in range(G.shape[1]): for lab in lstlab: temp = G[labels_a == lab, i] - n = np.linalg.norm(temp) + n = nx.norm(temp) if n: W[labels_a == lab, i] = temp / n return W @@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (d, d) ndarray + L : (d, d) array-like Linear mapping matrix ((:math:`d+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1] if bias: - xs1 = np.hstack((xs, np.ones((ns, 1)))) - xstxs = xs1.T.dot(xs1) - Id = np.eye(d + 1) + xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d + 1, type_as=xs) Id[-1] = 0 I0 = Id[:, :-1] @@ -350,8 +357,8 @@ def 
joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, return x[:-1, :] else: xs1 = xs - xstxs = xs1.T.dot(xs1) - Id = np.eye(d) + xstxs = nx.dot(xs1.T, xs1) + Id = nx.eye(d, type_as=xs) I0 = Id def sel(x): @@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -368,23 +376,26 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2) + return ( + nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.sum(sel(L - I0) ** 2) + ) def solve_L(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0) + xst = ns * nx.dot(G, xt) + return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = xs1.dot(L) + xsi = nx.dot(xs1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain mu : float,optional Weight for the linear OT loss (>0) @@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters - L : (ns, d) ndarray + L : (ns, d) array-like Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias) log : dict log dictionary return only if log==True in parameters @@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', ot.optim.cg : General regularized OT """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) ns, nt = xs.shape[0], xt.shape[0] K = kernel(xs, xs, method=kerneltype, sigma=sigma) if bias: - K1 = np.hstack((K, np.ones((ns, 1)))) - Id = np.eye(ns + 1) + K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1) + Id = nx.eye(ns + 1, type_as=xs) Id[-1] = 0 - Kp = np.eye(ns + 1) + Kp = nx.eye(ns + 1, type_as=xs) Kp[:ns, :ns] = K # ls regu @@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', # Kreg=I # RKHS regul - K0 = K1.T.dot(K1) + eta * Kp + K0 = nx.dot(K1.T, K1) + eta * Kp Kreg = Kp else: K1 = K - Id = np.eye(ns) + Id = nx.eye(ns, type_as=xs) # ls regul # K0 = K1.T.dot(K1)+eta*I @@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', if log: log = {'err': []} - a, b = unif(ns), unif(nt) + a = unif(ns, type_as=xs) + b = unif(nt, type_as=xt) M = dist(xs, xt) * ns G = emd(a, b, M) @@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ - np.sum(G * M) + eta * 
np.trace(L.T.dot(Kreg).dot(L)) + return ( + nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2) + + mu * nx.sum(G * M) + + eta * nx.trace(dots(L.T, Kreg, L)) + ) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, xst) + xst = ns * nx.dot(G, xt) + return nx.solve(K0, xst) def solve_L_bias(G): """ solve L problem with fixed G (least square)""" - xst = ns * G.dot(xt) - return np.linalg.solve(K0, K1.T.dot(xst)) + xst = ns * nx.dot(G, xt) + return nx.solve(K0, nx.dot(K1.T, xst)) def solve_G(L, G0): """Update G with CG algorithm""" - xsi = K1.dot(L) + xsi = nx.dot(K1, L) def f(G): - return np.sum((xsi - ns * G.dot(xt)) ** 2) + return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2) def df(G): - return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T) G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) @@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Parameters ---------- - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in the target domain reg : float,optional regularization added to the diagonals of covariances (>0) - ws : np.ndarray (ns,1), optional + ws : array-like (ns,1), optional weights for the source samples - wt : np.ndarray (ns,1), optional + wt : array-like (ns,1), optional weights for the target samples bias: boolean, optional estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True) @@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, Returns ------- - A : (d, d) ndarray + A : (d, d) array-like Linear operator - b : (1, d) ndarray + b : (1, d) array-like bias log : dict log dictionary return only if log==True in parameters @@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, """ + xs, xt = list_to_array(xs, xt) + nx = get_backend(xs, xt) d = xs.shape[1] if bias: - mxs = xs.mean(0, keepdims=True) - mxt = xt.mean(0, keepdims=True) + mxs = nx.mean(xs, axis=0)[None, :] + mxt = nx.mean(xt, axis=0)[None, :] xs = xs - mxs xt = xt - mxt else: - mxs = np.zeros((1, d)) - mxt = np.zeros((1, d)) + mxs = nx.zeros((1, d), type_as=xs) + mxt = nx.zeros((1, d), type_as=xs) if ws is None: - ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0] if wt is None: - wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0] - Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) - Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs) + Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt) - Cs12 = linalg.sqrtm(Cs) - Cs_12 = linalg.inv(Cs12) + Cs12 = nx.sqrtm(Cs) + Cs_12 = nx.inv(Cs12) - M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + M0 = nx.sqrtm(dots(Cs12, Ct, Cs12)) - A = Cs_12.dot(M0.dot(Cs_12)) + A = dots(Cs_12, M0, Cs_12) - b = mxt - mxs.dot(A) + b = mxt - nx.dot(mxs, A) if log: log = {} @@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Parameters ---------- - a : np.ndarray (ns,) + a : array-like (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : array-like (nt,) samples weights in the target domain - xs : np.ndarray (ns,d) + xs : array-like (ns,d) samples in the source domain - xt : np.ndarray (nt,d) + xt : array-like (nt,d) samples in 
the target domain - M : np.ndarray (ns,nt) + M : array-like (ns,nt) loss matrix sim : string, optional Type of similarity ('knn' or 'gauss') used to construct the Laplacian. @@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al Returns ------- - gamma : (ns, nt) ndarray + gamma : (ns, nt) array-like Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al raise ValueError( 'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__)) + a, b, xs, xt, M = list_to_array(a, b, xs, xt, M) + nx = get_backend(a, b, xs, xt, M) + if sim == 'gauss': if sim_param is None: - sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) sS = kernel(xs, xs, method=sim, sigma=sim_param) sT = kernel(xt, xt, method=sim, sigma=sim_param) @@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al from sklearn.neighbors import kneighbors_graph - sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray() + sS = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xs), n_neighbors=int(sim_param) + ).toarray(), type_as=xs) sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray() + sT = nx.from_numpy(kneighbors_graph( + X=nx.to_numpy(xt), n_neighbors=int(sim_param) + ).toarray(), type_as=xt) sT = (sT + sT.T) / 2 else: raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim)) @@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al lT = laplacian(sT) def f(G): - return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ - + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + return ( + alpha * nx.trace(dots(xt.T, G.T, lS, G, xt)) + + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs)) + ) ls2 = lS + lS.T lt2 = lT + lT.T - xt2 = np.dot(xt, xt.T) + xt2 = nx.dot(xt, xt.T) if reg == 'disp': Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T) @@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al M = M + Cs + Ct def df(G): - return alpha * np.dot(ls2, np.dot(G, xt2))\ - + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2))) + return ( + alpha * dots(ls2, G, xt2) + + (1 - alpha) * dots(xs, xs.T, G, lt2) + ) return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) @@ -919,7 +949,7 @@ def distribution_estimation_uniform(X): The uniform distribution estimated from :math:`\mathbf{X}` """ - return unif(X.shape[0]) + return unif(X.shape[0], type_as=X) class BaseTransport(BaseEstimator): @@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator): self : object Returns self. 
""" + nx = self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator): if (ys is not None) and (yt is not None): if self.limit_max != np.infty: - self.limit_max = self.limit_max * np.max(self.cost_) + self.limit_max = self.limit_max * nx.max(self.cost_) # assumes labeled source samples occupy the first rows # and labeled target samples occupy the first columns - classes = [c for c in np.unique(ys) if c != -1] + classes = [c for c in nx.unique(ys) if c != -1] for c in classes: - idx_s = np.where((ys != c) & (ys != -1)) - idx_t = np.where(yt == c) + idx_s = nx.where((ys != c) & (ys != -1)) + idx_t = nx.where(yt == c) # all the coefficients corresponding to a source sample # and a target sample : @@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator): for bi in batch_ind: # get the nearest neighbor in the source domain D0 = dist(Xs[bi], self.xs_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the source samples - transp = self.coupling_ / np.sum( - self.coupling_, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_ = np.dot(transp, self.xt_) + transp = self.coupling_ / nx.sum( + self.coupling_, axis=1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_ = nx.dot(transp, self.xt_) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator): International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - ysTemp = label_normalization(np.copy(ys)) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys)) + classes = nx.unique(ysTemp) n = len(classes) - D1 = np.zeros((n, len(ysTemp))) + D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True) + transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - transp_ys = np.dot(D1, transp) + transp_ys = nx.dot(D1, transp) return transp_ys.T @@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - if np.array_equal(self.xt_, Xt): + if nx.array_equal(self.xt_, Xt): # perform standard barycentric mapping - transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None] + transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None] # set nans to 0 - transp_[~ np.isfinite(transp_)] = 0 + transp_[~ nx.isfinite(transp_)] = 0 # compute transported samples - transp_Xt = np.dot(transp_, self.xs_) + transp_Xt = nx.dot(transp_, self.xs_) else: # perform out of sample mapping - indices = np.arange(Xt.shape[0]) + indices = nx.arange(Xt.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: D0 = dist(Xt[bi], self.xt_) - idx = np.argmin(D0, axis=1) + idx = nx.argmin(D0, axis=1) # transport the target samples - transp_ = self.coupling_.T / np.sum( + transp_ = self.coupling_.T / nx.sum( self.coupling_, 0)[:, None] - transp_[~ np.isfinite(transp_)] = 0 - transp_Xt_ = np.dot(transp_, self.xs_) + transp_[~ nx.isfinite(transp_)] = 0 + transp_Xt_ = nx.dot(transp_, self.xs_) # define the transported points transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :] transp_Xt.append(transp_Xt_) - transp_Xt = np.concatenate(transp_Xt, axis=0) + transp_Xt = nx.concatenate(transp_Xt, axis=0) return transp_Xt @@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator): transp_ys : array-like, shape (n_source_samples, nb_classes) Estimated soft source labels. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_) # perform label propagation - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 for c in classes: D1[int(c), ytTemp == c] = 1 # compute propagated samples - transp_ys = np.dot(D1, transp.T) + transp_ys = nx.dot(D1, transp.T) return transp_ys.T @@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport): self : object Returns self. """ + nx = self._get_backend(Xs, ys, Xt, yt) self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) # coupling estimation returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, - ws=self.mu_s.reshape((-1, 1)), - wt=self.mu_t.reshape((-1, 1)), + ws=nx.reshape(self.mu_s, (-1, 1)), + wt=nx.reshape(self.mu_t, (-1, 1)), bias=self.bias, log=self.log) # deal with the value of log @@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport): self.log_ = dict() # re compute inverse mapping - self.A1_ = linalg.inv(self.A_) - self.B1_ = -self.B_.dot(self.A1_) + self.A1_ = nx.inv(self.A_) + self.B1_ = -nx.dot(self.B_, self.A1_) return self @@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs = Xs.dot(self.A_) + self.B_ + transp_Xs = nx.dot(Xs, self.A_) + self.B_ return transp_Xs @@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport): transp_Xt : array-like, shape (n_source_samples, n_features) The transported target samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt = Xt.dot(self.A1_) + self.B1_ + transp_Xt = nx.dot(Xt, self.A1_) + self.B1_ return transp_Xt @@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator): self : object Returns self """ + self._get_backend(Xs, ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt): @@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator): transp_Xs : array-like, shape (n_source_samples, n_features) The transport source samples. """ + nx = self.nx # check the necessary inputs parameters are here if check_params(Xs=Xs): - if np.array_equal(self.xs_, Xs): + if nx.array_equal(self.xs_, Xs): # perform standard barycentric mapping - transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs = np.dot(transp, self.xt_) + transp_Xs = nx.dot(transp, self.xt_) else: if self.kernel == "gaussian": K = kernel(Xs, self.xs_, method=self.kernel, @@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator): elif self.kernel == "linear": K = Xs if self.bias: - K = np.hstack((K, np.ones((Xs.shape[0], 1)))) - transp_Xs = K.dot(self.mapping_) + K = nx.concatenate( + [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1 + ) + transp_Xs = nx.dot(K, self.mapping_) return transp_Xs @@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport): self : object Returns self. 
""" + self._get_backend(*Xs, *ys, Xt, yt) # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): @@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport): batch_size : int, optional (default=128) The batch size for out of sample inverse transform """ + nx = self.nx transp_Xs = [] # check the necessary inputs parameters are here if check_params(Xs=Xs): - if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]): # perform standard barycentric mapping for each source domain for coupling in self.coupling_: - transp = coupling / np.sum(coupling, 1)[:, None] + transp = coupling / nx.sum(coupling, 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute transported samples - transp_Xs.append(np.dot(transp, self.xt_)) + transp_Xs.append(nx.dot(transp, self.xt_)) else: # perform out of sample mapping - indices = np.arange(Xs.shape[0]) + indices = nx.arange(Xs.shape[0]) batch_ind = [ indices[i:i + batch_size] for i in range(0, len(indices), batch_size)] @@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport): transp_Xs_ = [] # get the nearest neighbor in the sources domains - xs = np.concatenate(self.xs_, axis=0) - idx = np.argmin(dist(Xs[bi], xs), axis=1) + xs = nx.concatenate(self.xs_, axis=0) + idx = nx.argmin(dist(Xs[bi], xs), axis=1) # transport the source samples for coupling in self.coupling_: - transp = coupling / np.sum( - coupling, 1)[:, None] - transp[~ np.isfinite(transp)] = 0 - transp_Xs_.append(np.dot(transp, self.xt_)) + transp = coupling / nx.sum(coupling, 1)[:, None] + transp[~ nx.isfinite(transp)] = 0 + transp_Xs_.append(nx.dot(transp, self.xt_)) - transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + transp_Xs_ = nx.concatenate(transp_Xs_, axis=0) # define the transported points transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] transp_Xs.append(transp_Xs_) - transp_Xs = np.concatenate(transp_Xs, axis=0) + transp_Xs = nx.concatenate(transp_Xs, axis=0) return transp_Xs @@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport): "Optimal transport for multi-source domain adaptation under target shift", International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
""" + nx = self.nx # check the necessary inputs parameters are here if check_params(ys=ys): - yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) + yt = nx.zeros( + (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]), + type_as=ys[0] + ) for i in range(len(ys)): - ysTemp = label_normalization(np.copy(ys[i])) - classes = np.unique(ysTemp) + ysTemp = label_normalization(nx.copy(ys[i])) + classes = nx.unique(ysTemp) n = len(classes) ns = len(ysTemp) # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 if self.log: D1 = self.log_['D1'][i] else: - D1 = np.zeros((n, ns)) + D1 = nx.zeros((n, ns), type_as=transp) for c in classes: D1[int(c), ysTemp == c] = 1 # compute propagated labels - yt = yt + np.dot(D1, transp) / len(ys) + yt = yt + nx.dot(D1, transp) / len(ys) return yt.T @@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport): transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) A list of estimated soft source labels """ + nx = self.nx # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - ytTemp = label_normalization(np.copy(yt)) - classes = np.unique(ytTemp) + ytTemp = label_normalization(nx.copy(yt)) + classes = nx.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(ytTemp))) + D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0]) for c in classes: D1[int(c), ytTemp == c] = 1 @@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport): for i in range(len(self.xs_)): # perform label propagation - transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None] # set nans to 0 - transp[~ np.isfinite(transp)] = 0 + transp[~ nx.isfinite(transp)] = 0 # compute propagated labels - transp_ys.append(np.dot(D1, transp.T).T) + transp_ys.append(nx.dot(D1, transp.T).T) return transp_ys -- cgit v1.2.3