From fa99199c02e497354e34c6ce76e7b4ba15b44d05 Mon Sep 17 00:00:00 2001
From: ievred
Date: Fri, 3 Apr 2020 16:06:39 +0200
Subject: v2 laplace emd sinkhorn

---
 ot/da.py | 485 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 476 insertions(+), 9 deletions(-)

diff --git a/ot/da.py b/ot/da.py
index e62e495..39e8c4c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -16,7 +16,7 @@ import scipy.linalg as linalg
 from .bregman import sinkhorn, jcpot_barycenter
 from .lp import emd
-from .utils import unif, dist, kernel, cost_normalization
+from .utils import unif, dist, kernel, cost_normalization, laplacian
 from .utils import check_params, BaseEstimator
 from .unbalanced import sinkhorn_unbalanced
 from .optim import cg
@@ -748,6 +748,233 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
     return A, b
 
+def emd_laplace(a, b, xs, xt, M, eta=1., alpha=0.5,
+                numItermax=1000, stopThr=1e-5, numInnerItermax=1000,
+                stopInnerThr=1e-6, log=False, verbose=False, **kwargs):
+    r"""Solve the optimal transport problem (OT) with Laplacian regularization
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F + \eta\cdot\Omega_\alpha(\gamma)
+
+        s.t.\ \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma\geq 0
+
+    where:
+
+    - a and b are source and target weights (sum to 1)
+    - xs and xt are source and target samples
+    - M is the (ns,nt) metric cost matrix
+    - :math:`\Omega_\alpha` is the Laplacian regularization term
+      :math:`\Omega_\alpha = \frac{\alpha}{n_s^2}\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2 + \frac{1-\alpha}{n_t^2}\sum_{i,j}S^t_{i,j}\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2`
+      with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping
+
+    The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5].
+
+    Parameters
+    ----------
+    a : np.ndarray (ns,)
+        samples weights in the source domain
+    b : np.ndarray (nt,)
+        samples weights in the target domain
+    xs : np.ndarray (ns,d)
+        samples in the source domain
+    xt : np.ndarray (nt,d)
+        samples in the target domain
+    M : np.ndarray (ns,nt)
+        loss matrix
+    eta : float
+        Regularization term for Laplacian regularization
+    alpha : float
+        Regularization term for source domain's importance in regularization
+    numItermax : int, optional
+        Max number of iterations (outer CG loop)
+    stopThr : float, optional
+        Stop threshold on the relative variation of the loss (outer CG loop) (>0)
+    numInnerItermax : int, optional
+        Max number of iterations (inner emd solver)
+    stopInnerThr : float, optional
+        Stop threshold on the absolute variation of the loss (outer CG loop) (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    gamma : (ns x nt) ndarray
+        Optimal transportation matrix for the given parameters
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    References
+    ----------
+
+    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+           "Optimal Transport for Domain Adaptation," in IEEE
+           Transactions on Pattern Analysis and Machine Intelligence ,
+           vol.PP, no.99, pp.1-1
+
+    See Also
+    --------
+    ot.lp.emd : Unregularized OT
+    ot.optim.cg : General regularized OT
+
+    """
+    if 'sim' not in kwargs:
+        kwargs['sim'] = 'knn'
+
+    if kwargs['sim'] == 'gauss':
+        if 'rbfparam' not in kwargs:
+            kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
+        sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam'])
+        sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam'])
+
+    elif kwargs['sim'] == 'knn':
+        if 'nn' not in kwargs:
+            kwargs['nn'] = 5
+
+        from sklearn.neighbors import kneighbors_graph
+
+        sS = kneighbors_graph(xs, kwargs['nn']).toarray()
+        sS = (sS + sS.T) / 2
+        sT = kneighbors_graph(xt, kwargs['nn']).toarray()
+        sT = (sT + sT.T) / 2
+
+    lS = laplacian(sS)
+    lT = laplacian(sT)
+
+    def f(G):
+        return alpha*np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
+            + (1-alpha)*np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
+
+    def df(G):
+        return alpha*np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\
+            +(1-alpha)*np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T)))
+
+    return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
+              stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
+
+def sinkhorn_laplace(a, b, xs, xt, M, reg=.1, eta=1., alpha=0.5,
+                     numItermax=1000, stopThr=1e-5, numInnerItermax=1000,
+                     stopInnerThr=1e-6, log=False, verbose=False, **kwargs):
+    r"""Solve the entropic regularized optimal transport problem (OT) with Laplacian regularization
+
+    .. math::
+        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) + \eta\cdot\Omega_\alpha(\gamma)
+
+        s.t.\ \gamma 1 = a
+
+             \gamma^T 1 = b
+
+             \gamma\geq 0
+
+    where:
+
+    - a and b are source and target weights (sum to 1)
+    - xs and xt are source and target samples
+    - M is the (ns,nt) metric cost matrix
+    - :math:`\Omega_e` is the entropic regularization term
+      :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+    - :math:`\Omega_\alpha` is the Laplacian regularization term
+      :math:`\Omega_\alpha = \frac{\alpha}{n_s^2}\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2 + \frac{1-\alpha}{n_t^2}\sum_{i,j}S^t_{i,j}\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2`
+      with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping
+
+    The algorithm used for solving the problem is the generalized conditional gradient algorithm as proposed in [5].
+ + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term for entropic regularization >0 + eta : float + Regularization term for Laplacian regularization + alpha : float + Regularization term for source domain's importance in regularization + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (inner sinkhorn solver) (>0) + numInnerItermax : int, optional + Max number of iterations (inner CG solver) + stopInnerThr : float, optional + Stop threshold on error (inner CG solver) (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE + Transactions on Pattern Analysis and Machine Intelligence , + vol.PP, no.99, pp.1-1 + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + if 'sim' not in kwargs: + kwargs['sim'] = 'knn' + + if kwargs['sim'] == 'gauss': + if 'rbfparam' not in kwargs: + kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) + sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam']) + sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam']) + + elif kwargs['sim'] == 'knn': + if 'nn' not in kwargs: + kwargs['nn'] = 5 + + from sklearn.neighbors import kneighbors_graph + + sS = kneighbors_graph(xs, kwargs['nn']).toarray() + sS = (sS + sS.T) / 2 + sT = kneighbors_graph(xt, kwargs['nn']).toarray() + sT = (sT + sT.T) / 2 + + lS = laplacian(sS) + lT = laplacian(sT) + + def f(G): + return alpha*np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ + + (1-alpha)*np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) + + def df(G): + return alpha*np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\ + +(1-alpha)*np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T))) + + return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax, stopThr=stopThr, + numInnerItermax=numInnerItermax, stopThr2=stopInnerThr, + verbose=verbose, log=log) + def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X @@ -762,10 +989,12 @@ def distribution_estimation_uniform(X): The uniform distribution estimated from X """ + return unif(X.shape[0]) class BaseTransport(BaseEstimator): + """Base class for OTDA objects Notes @@ -787,6 +1016,7 @@ class BaseTransport(BaseEstimator): inverse_transform method should always get as input a Xt parameter """ + def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -847,6 +1077,7 @@ class BaseTransport(BaseEstimator): return self + def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) and transports source samples Xs onto target @@ -875,6 +1106,7 @@ class BaseTransport(BaseEstimator): return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) + def 
transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports source samples Xs onto target ones Xt @@ -942,6 +1174,7 @@ class BaseTransport(BaseEstimator): return transp_Xs + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports target samples Xt onto target samples Xs @@ -1011,6 +1244,7 @@ class BaseTransport(BaseEstimator): class LinearTransport(BaseTransport): + """ OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two @@ -1053,14 +1287,15 @@ class LinearTransport(BaseTransport): """ + def __init__(self, reg=1e-8, bias=True, log=False, distribution_estimation=distribution_estimation_uniform): - self.bias = bias self.log = log self.reg = reg self.distribution_estimation = distribution_estimation + def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -1108,6 +1343,7 @@ class LinearTransport(BaseTransport): return self + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports source samples Xs onto target ones Xt @@ -1140,6 +1376,7 @@ class LinearTransport(BaseTransport): return transp_Xs + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports target samples Xt onto target samples Xs @@ -1175,6 +1412,7 @@ class LinearTransport(BaseTransport): class SinkhornTransport(BaseTransport): + """Domain Adapatation OT method based on Sinkhorn Algorithm Parameters @@ -1223,12 +1461,12 @@ class SinkhornTransport(BaseTransport): 26, 2013 """ + def __init__(self, reg_e=1., max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): - self.reg_e = reg_e self.max_iter = max_iter self.tol = tol @@ -1240,6 +1478,7 @@ class SinkhornTransport(BaseTransport): self.distribution_estimation = distribution_estimation self.out_of_sample_map = out_of_sample_map + def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -1284,6 +1523,7 @@ class SinkhornTransport(BaseTransport): class EMDTransport(BaseTransport): + """Domain Adapatation OT method based on Earth Mover's Distance Parameters @@ -1321,11 +1561,11 @@ class EMDTransport(BaseTransport): on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 """ + def __init__(self, metric="sqeuclidean", norm=None, log=False, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10, max_iter=100000): - self.metric = metric self.norm = norm self.log = log @@ -1334,6 +1574,7 @@ class EMDTransport(BaseTransport): self.out_of_sample_map = out_of_sample_map self.max_iter = max_iter + def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -1375,6 +1616,7 @@ class EMDTransport(BaseTransport): class SinkhornLpl1Transport(BaseTransport): + """Domain Adapatation OT method based on sinkhorn algorithm + LpL1 class regularization. 
@@ -1429,13 +1671,13 @@ class SinkhornLpl1Transport(BaseTransport): """ + def __init__(self, reg_e=1., reg_cl=0.1, max_iter=10, max_inner_iter=200, log=False, tol=10e-9, verbose=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): - self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter @@ -1449,6 +1691,7 @@ class SinkhornLpl1Transport(BaseTransport): self.out_of_sample_map = out_of_sample_map self.limit_max = limit_max + def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -1493,7 +1736,222 @@ class SinkhornLpl1Transport(BaseTransport): return self +class EMDLaplaceTransport(BaseTransport): + + """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization + + Parameters + ---------- + reg_lap : float, optional (default=1) + Laplacian regularization parameter + reg_src : float, optional (default=0.5) + Source relative importance in regularization + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + max_iter : int, optional (default=100) + Max number of BCD iterations + tol : float, optional (default=1e-5) + Stop threshold on relative loss decrease (>0) + max_inner_iter : int, optional (default=10) + Max number of iterations (inner CG solver) + inner_tol : float, optional (default=1e-6) + Stop threshold on error (inner CG solver) (>0) + log : int, optional (default=False) + Controls the logs of the optimization algorithm + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + + References + ---------- + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + """ + + + def __init__(self, reg_lap = 1., reg_src=1., alpha=0.5, + metric="sqeuclidean", norm=None, max_iter=100, tol=1e-5, + max_inner_iter=100000, inner_tol=1e-6, log=False, verbose=False, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans'): + self.reg_lap = reg_lap + self.reg_src = reg_src + self.alpha = alpha + self.metric = metric + self.norm = norm + self.max_iter = max_iter + self.tol = tol + self.max_inner_iter = max_inner_iter + self.inner_tol = inner_tol + self.log = log + self.verbose = verbose + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. 
+ yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, + xt=self.xt_, M=self.cost_, eta=self.reg_lap, alpha=self.reg_src, + numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, + stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) + + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self + +class SinkhornLaplaceTransport(BaseTransport): + + """Domain Adapatation OT method based on entropic regularized OT with Laplacian regularization + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + reg_lap : float, optional (default=1) + Laplacian regularization parameter + reg_src : float, optional (default=0.5) + Source relative importance in regularization + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + max_iter : int, optional (default=100) + Max number of BCD iterations + tol : float, optional (default=1e-5) + Stop threshold on relative loss decrease (>0) + max_inner_iter : int, optional (default=10) + Max number of iterations (inner CG solver) + inner_tol : float, optional (default=1e-6) + Stop threshold on error (inner CG solver) (>0) + log : int, optional (default=False) + Controls the logs of the optimization algorithm + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + + Attributes + ---------- + coupling_ : array-like, shape (n_source_samples, n_target_samples) + The optimal coupling + + References + ---------- + .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, + "Optimal Transport for Domain Adaptation," in IEEE Transactions + on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 + """ + + + def __init__(self, reg_e=1., reg_lap=1., reg_src=0.5, + metric="sqeuclidean", norm=None, max_iter=100, tol=1e-9, + max_inner_iter=200, inner_tol=1e-6, log=False, verbose=False, + distribution_estimation=distribution_estimation_uniform, + out_of_sample_map='ferradans'): + + self.reg_e = reg_e + self.reg_lap = reg_lap + self.reg_src = reg_src + self.metric = metric + self.norm = norm + self.max_iter = max_iter + self.tol = tol + self.max_inner_iter = max_inner_iter + self.inner_tol = inner_tol + self.log = log + self.verbose = verbose + self.distribution_estimation = distribution_estimation + self.out_of_sample_map = out_of_sample_map + + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. 
+ ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + super(SinkhornLaplaceTransport, self).fit(Xs, ys, Xt, yt) + + returned_ = sinkhorn_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, + xt=self.xt_, M=self.cost_, reg=self.reg_e, eta=self.reg_lap, alpha=self.reg_src, + numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, + stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) + + # coupling estimation + if self.log: + self.coupling_, self.log_ = returned_ + else: + self.coupling_ = returned_ + self.log_ = dict() + return self + + class SinkhornL1l2Transport(BaseTransport): + """Domain Adapatation OT method based on sinkhorn algorithm + l1l2 class regularization. @@ -1550,13 +2008,13 @@ class SinkhornL1l2Transport(BaseTransport): """ + def __init__(self, reg_e=1., reg_cl=0.1, max_iter=10, max_inner_iter=200, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): - self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter @@ -1570,6 +2028,7 @@ class SinkhornL1l2Transport(BaseTransport): self.out_of_sample_map = out_of_sample_map self.limit_max = limit_max + def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -1617,6 +2076,7 @@ class SinkhornL1l2Transport(BaseTransport): class MappingTransport(BaseEstimator): + """MappingTransport: DA methods that aims at jointly estimating a optimal transport coupling and the associated mapping @@ -1673,11 +2133,11 @@ class MappingTransport(BaseEstimator): """ + def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean", norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5, max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False, verbose2=False): - self.metric = metric self.norm = norm self.mu = mu @@ -1693,6 +2153,7 @@ class MappingTransport(BaseEstimator): self.verbose = verbose self.verbose2 = verbose2 + def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Builds an optimal coupling and estimates the associated mapping from source and target sets of samples (Xs, ys) and (Xt, yt) @@ -1750,6 +2211,7 @@ class MappingTransport(BaseEstimator): return self + def transform(self, Xs): """Transports source samples Xs onto target ones Xt @@ -1790,6 +2252,7 @@ class MappingTransport(BaseEstimator): class UnbalancedSinkhornTransport(BaseTransport): + """Domain Adapatation unbalanced OT method based on sinkhorn algorithm Parameters @@ -1842,12 +2305,12 @@ class UnbalancedSinkhornTransport(BaseTransport): """ + def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn', max_iter=10, tol=1e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): - self.reg_e = reg_e self.reg_m = reg_m self.method = method @@ -1861,6 +2324,7 @@ class UnbalancedSinkhornTransport(BaseTransport): self.out_of_sample_map = out_of_sample_map self.limit_max = limit_max + def fit(self, Xs, ys=None, Xt=None, yt=None): """Build a coupling matrix from 
source and target sets of samples
        (Xs, ys) and (Xt, yt)
@@ -1908,6 +2372,7 @@ class UnbalancedSinkhornTransport(BaseTransport):
 
 class JCPOTTransport(BaseTransport):
+
     """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
 
     Parameters
@@ -1954,11 +2419,11 @@
 
     """
 
+
     def __init__(self, reg_e=.1, max_iter=10,
                  tol=10e-9, verbose=False, log=False,
                  metric="sqeuclidean",
                  out_of_sample_map='ferradans'):
-
        self.reg_e = reg_e
        self.max_iter = max_iter
        self.tol = tol
@@ -1967,6 +2432,7 @@
        self.metric = metric
        self.out_of_sample_map = out_of_sample_map
 
+
    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Building coupling matrices from a list of source
        and target sets of samples
        (Xs, ys) and (Xt, yt)
@@ -2011,6 +2477,7 @@
 
        return self
 
+
    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
        """Transports source samples Xs onto target ones Xt
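
For readers who want to try the new solvers, here is a minimal usage sketch (an editorial illustration, not part of the patch). It assumes the patch is applied so that emd_laplace and sinkhorn_laplace are importable from ot.da; the toy data, regularization values and similarity settings are illustrative only, and the 'knn' similarity requires scikit-learn.

import numpy as np
import ot
from ot.da import emd_laplace, sinkhorn_laplace  # functions added by this patch

# illustrative toy data: 2D source samples and a shifted target cloud
rng = np.random.RandomState(42)
ns, nt = 50, 60
xs = rng.randn(ns, 2)
xt = rng.randn(nt, 2) + 2

a = ot.unif(ns)                             # uniform source weights
b = ot.unif(nt)                             # uniform target weights
M = ot.dist(xs, xt, metric='sqeuclidean')   # ground cost matrix

# exact OT with Laplacian regularization (outer conditional gradient,
# inner EMD); 'knn' builds a symmetrized k-nearest-neighbour similarity graph
G_emd = emd_laplace(a, b, xs, xt, M, eta=1., alpha=0.5, sim='knn', nn=5)

# entropic OT with Laplacian regularization (generalized conditional
# gradient with inner Sinkhorn iterations)
G_sink = sinkhorn_laplace(a, b, xs, xt, M, reg=.1, eta=1., alpha=0.5, sim='knn', nn=5)

print(G_emd.shape, G_sink.shape)  # both couplings have shape (ns, nt)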
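
The same solvers are also wrapped in the scikit-learn style classes added further down in the patch (EMDLaplaceTransport and SinkhornLaplaceTransport). A minimal sketch of that API follows; again this is an editorial illustration with arbitrary data and hyper-parameter values, not part of the patch itself.

import numpy as np
from ot.da import EMDLaplaceTransport, SinkhornLaplaceTransport  # classes added by this patch

rng = np.random.RandomState(0)
Xs = rng.randn(50, 2)        # source samples
Xt = rng.randn(60, 2) + 2    # target samples

# exact OT + Laplacian regularization
ot_emd_lap = EMDLaplaceTransport(reg_lap=1., reg_src=0.5)
ot_emd_lap.fit(Xs=Xs, Xt=Xt)
Xs_mapped = ot_emd_lap.transform(Xs=Xs)   # barycentric mapping of Xs onto the target domain

# entropic OT + Laplacian regularization
ot_sinkhorn_lap = SinkhornLaplaceTransport(reg_e=1., reg_lap=1., reg_src=0.5)
ot_sinkhorn_lap.fit(Xs=Xs, Xt=Xt)
G = ot_sinkhorn_lap.coupling_             # estimated coupling, shape (50, 60)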