diff options
author | Ievgen Redko <ievgen.redko@univ-st-etienne.fr> | 2020-04-15 17:02:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-15 17:02:18 +0200 |
commit | f9499dce4e14c39bd42efeee22ad861933326a8b (patch) | |
tree | 73c0bc7021b945d43b667cea76e4424e5f6e1147 /ot/da.py | |
parent | 2571a3ead11a7fc010ed20b1af6faeef464565a1 (diff) | |
parent | adc5570550676b63b9aabb2205a67c5b7c9187f3 (diff) |
Merge branch 'master' into laplace_da
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 343 |
1 files changed, 342 insertions, 1 deletions
@@ -14,7 +14,7 @@ Domain adaptation with optimal transport import numpy as np import scipy.linalg as linalg -from .bregman import sinkhorn +from .bregman import sinkhorn, jcpot_barycenter from .lp import emd from .utils import unif, dist, kernel, cost_normalization, label_normalization from .utils import check_params, BaseEstimator @@ -895,6 +895,9 @@ class BaseTransport(BaseEstimator): transform method should always get as input a Xs parameter inverse_transform method should always get as input a Xt parameter + + transform_labels method should always get as input a ys parameter + inverse_transform_labels method should always get as input a yt parameter """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): @@ -1052,6 +1055,50 @@ class BaseTransport(BaseEstimator): return transp_Xs + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain estimated target labels as in [27] + + Parameters + ---------- + ys : array-like, shape (n_source_samples,) + The class labels + + Returns + ------- + transp_ys : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + + ysTemp = label_normalization(np.copy(ys)) + classes = np.unique(ysTemp) + n = len(classes) + D1 = np.zeros((n, len(ysTemp))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + for c in classes: + D1[int(c), ysTemp == c] = 1 + + # compute propagated labels + transp_ys = np.dot(D1, transp) + + return transp_ys.T + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports target samples Xt onto target samples Xs @@ -1119,6 +1166,41 @@ class BaseTransport(BaseEstimator): return transp_Xt + def inverse_transform_labels(self, yt=None): + """Propagate target labels yt to obtain estimated source labels ys + + Parameters + ---------- + yt : array-like, shape (n_target_samples,) + + Returns + ------- + transp_ys : array-like, shape (n_source_samples, nb_classes) + Estimated soft source labels. + """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) + n = len(classes) + D1 = np.zeros((n, len(ytTemp))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + for c in classes: + D1[int(c), ytTemp == c] = 1 + + # compute propagated samples + transp_ys = np.dot(D1, transp.T) + + return transp_ys.T + class LinearTransport(BaseTransport): @@ -2122,3 +2204,262 @@ class UnbalancedSinkhornTransport(BaseTransport): self.log_ = dict() return self + + +class JCPOTTransport(BaseTransport): + + """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + + Attributes + ---------- + coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples) + A set of optimal couplings between each source domain and the target domain + proportions_ : array-like, shape (n_classes,) + Estimated class proportions in the target domain + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), + vol. 89, p.849-858, 2019. + + """ + + def __init__(self, reg_e=.1, max_iter=10, + tol=10e-9, verbose=False, log=False, + metric="sqeuclidean", + out_of_sample_map='ferradans'): + self.reg_e = reg_e + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.out_of_sample_map = out_of_sample_map + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Building coupling matrices from a list of source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt, ys=ys): + + self.xs_ = Xs + self.xt_ = Xt + + returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, + metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=True) + + self.coupling_ = returned_[1]['gamma'] + + # deal with the value of log + if self.log: + self.proportions_, self.log_ = returned_ + else: + self.proportions_ = returned_ + self.log_ = dict() + + return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + """ + + transp_Xs = [] + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + + # perform standard barycentric mapping for each source domain + + for coupling in self.coupling_: + transp = coupling / np.sum(coupling, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs.append(np.dot(transp, self.xt_)) + else: + + # perform out of sample mapping + indices = np.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xs = [] + + for bi in batch_ind: + transp_Xs_ = [] + + # get the nearest neighbor in the sources domains + xs = np.concatenate(self.xs_, axis=0) + idx = np.argmin(dist(Xs[bi], xs), axis=1) + + # transport the source samples + for coupling in self.coupling_: + transp = coupling / np.sum( + coupling, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_.append(np.dot(transp, self.xt_)) + + transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + + # define the transported points + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] + transp_Xs.append(transp_Xs_) + + transp_Xs = np.concatenate(transp_Xs, axis=0) + + return transp_Xs + + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain target labels as in [27] + + Parameters + ---------- + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + + Returns + ------- + yt : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) + for i in range(len(ys)): + ysTemp = label_normalization(np.copy(ys[i])) + classes = np.unique(ysTemp) + n = len(classes) + ns = len(ysTemp) + + # perform label propagation + transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if self.log: + D1 = self.log_['D1'][i] + else: + D1 = np.zeros((n, ns)) + + for c in classes: + D1[int(c), ysTemp == c] = 1 + + # compute propagated labels + yt = yt + np.dot(D1, transp) / len(ys) + + return yt.T + + def inverse_transform_labels(self, yt=None): + """Propagate source labels ys to obtain target labels + + Parameters + ---------- + yt : array-like, shape (n_source_samples,) + The target class labels + + Returns + ------- + transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) + A list of estimated soft source labels + """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + transp_ys = [] + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) + n = len(classes) + D1 = np.zeros((n, len(ytTemp))) + + for c in classes: + D1[int(c), ytTemp == c] = 1 + + for i in range(len(self.xs_)): + + # perform label propagation + transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute propagated labels + transp_ys.append(np.dot(D1, transp.T).T) + + return transp_ys |