diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-08 16:34:39 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-08 16:34:39 +0200 |
commit | 1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (patch) | |
tree | ed3835181028050245d555548fcac6714122ae1d /ot/da.py | |
parent | 0b402fd7c7e07043afd3a9df9d75bc424730b06f (diff) |
added label normalization to utils
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 75 |
1 files changed, 37 insertions, 38 deletions
@@ -16,7 +16,7 @@ import scipy.linalg as linalg from .bregman import sinkhorn, jcpot_barycenter from .lp import emd -from .utils import unif, dist, kernel, cost_normalization +from .utils import unif, dist, kernel, cost_normalization, label_normalization from .utils import check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg @@ -786,6 +786,9 @@ class BaseTransport(BaseEstimator): transform method should always get as input a Xs parameter inverse_transform method should always get as input a Xt parameter + + transform_labels method should always get as input a ys parameter + inverse_transform_labels method should always get as input a yt parameter """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): @@ -944,7 +947,7 @@ class BaseTransport(BaseEstimator): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain estimated target labels + """Propagate source labels ys to obtain estimated target labels as in [27] Parameters ---------- @@ -955,14 +958,23 @@ class BaseTransport(BaseEstimator): ------- transp_ys : array-like, shape (n_target_samples,) Estimated target labels. + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, RĂ©mi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + """ # check the necessary inputs parameters are here if check_params(ys=ys): - classes = np.unique(ys) + ysTemp = label_normalization(np.copy(ys)) + classes = np.unique(ysTemp) n = len(classes) - D1 = np.zeros((n, len(ys))) + D1 = np.zeros((n, len(ysTemp))) # perform label propagation transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] @@ -970,18 +982,13 @@ class BaseTransport(BaseEstimator): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - if np.min(classes) != 0: - ys = ys - np.min(classes) - classes = np.unique(ys) - for c in classes: - D1[int(c), ys == c] = 1 + D1[int(c), ysTemp == c] = 1 # compute transported samples transp_ys = np.dot(D1, transp) - return np.argmax(transp_ys,axis=0) - + return np.argmax(transp_ys, axis=0) def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): @@ -1066,9 +1073,10 @@ class BaseTransport(BaseEstimator): # check the necessary inputs parameters are here if check_params(yt=yt): - classes = np.unique(yt) + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(yt))) + D1 = np.zeros((n, len(ytTemp))) # perform label propagation transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] @@ -1076,17 +1084,13 @@ class BaseTransport(BaseEstimator): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - if np.min(classes) != 0: - yt = yt - np.min(classes) - classes = np.unique(yt) - for c in classes: - D1[int(c), yt == c] = 1 + D1[int(c), ytTemp == c] = 1 # compute transported samples transp_ys = np.dot(D1, transp.T) - return np.argmax(transp_ys,axis=0) + return np.argmax(transp_ys, axis=0) class LinearTransport(BaseTransport): @@ -2163,7 +2167,7 @@ class JCPOTTransport(BaseTransport): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain target labels + """Propagate source labels ys to obtain target labels as in [27] Parameters ---------- @@ -2178,11 +2182,12 @@ class JCPOTTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(ys=ys): - yt = np.zeros((len(np.unique(np.concatenate(ys))),self.xt_.shape[0])) + yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) for i in range(len(ys)): - classes = np.unique(ys[i]) + ysTemp = label_normalization(np.copy(ys[i])) + classes = np.unique(ysTemp) n = len(classes) - ns = len(ys[i]) + ns = len(ysTemp) # perform label propagation transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] @@ -2195,16 +2200,13 @@ class JCPOTTransport(BaseTransport): else: D1 = np.zeros((n, ns)) - if np.min(classes) != 0: - ys = ys - np.min(classes) - classes = np.unique(ys) - for c in classes: - D1[int(c), ys == c] = 1 + D1[int(c), ysTemp == c] = 1 + # compute transported samples - yt = yt + np.dot(D1, transp)/len(ys) + yt = yt + np.dot(D1, transp) / len(ys) - return np.argmax(yt,axis=0) + return np.argmax(yt, axis=0) def inverse_transform_labels(self, yt=None): """Propagate source labels ys to obtain target labels @@ -2223,16 +2225,13 @@ class JCPOTTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - classes = np.unique(yt) + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(yt))) - - if np.min(classes) != 0: - yt = yt - np.min(classes) - classes = np.unique(yt) + D1 = np.zeros((n, len(ytTemp))) for c in classes: - D1[int(c), yt == c] = 1 + D1[int(c), ytTemp == c] = 1 for i in range(len(self.xs_)): @@ -2243,6 +2242,6 @@ class JCPOTTransport(BaseTransport): transp[~ np.isfinite(transp)] = 0 # compute transported labels - transp_ys.append(np.argmax(np.dot(D1, transp.T),axis=0)) + transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0)) return transp_ys |