diff options
-rw-r--r-- | ot/da.py | 75 | ||||
-rw-r--r-- | ot/utils.py | 22 | ||||
-rw-r--r-- | test/test_da.py | 1 |
3 files changed, 60 insertions, 38 deletions
@@ -16,7 +16,7 @@ import scipy.linalg as linalg from .bregman import sinkhorn, jcpot_barycenter from .lp import emd -from .utils import unif, dist, kernel, cost_normalization +from .utils import unif, dist, kernel, cost_normalization, label_normalization from .utils import check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg @@ -786,6 +786,9 @@ class BaseTransport(BaseEstimator): transform method should always get as input a Xs parameter inverse_transform method should always get as input a Xt parameter + + transform_labels method should always get as input a ys parameter + inverse_transform_labels method should always get as input a yt parameter """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): @@ -944,7 +947,7 @@ class BaseTransport(BaseEstimator): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain estimated target labels + """Propagate source labels ys to obtain estimated target labels as in [27] Parameters ---------- @@ -955,14 +958,23 @@ class BaseTransport(BaseEstimator): ------- transp_ys : array-like, shape (n_target_samples,) Estimated target labels. + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
+ """ # check the necessary inputs parameters are here if check_params(ys=ys): - classes = np.unique(ys) + ysTemp = label_normalization(np.copy(ys)) + classes = np.unique(ysTemp) n = len(classes) - D1 = np.zeros((n, len(ys))) + D1 = np.zeros((n, len(ysTemp))) # perform label propagation transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] @@ -970,18 +982,13 @@ class BaseTransport(BaseEstimator): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - if np.min(classes) != 0: - ys = ys - np.min(classes) - classes = np.unique(ys) - for c in classes: - D1[int(c), ys == c] = 1 + D1[int(c), ysTemp == c] = 1 # compute transported samples transp_ys = np.dot(D1, transp) - return np.argmax(transp_ys,axis=0) - + return np.argmax(transp_ys, axis=0) def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): @@ -1066,9 +1073,10 @@ class BaseTransport(BaseEstimator): # check the necessary inputs parameters are here if check_params(yt=yt): - classes = np.unique(yt) + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(yt))) + D1 = np.zeros((n, len(ytTemp))) # perform label propagation transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] @@ -1076,17 +1084,13 @@ class BaseTransport(BaseEstimator): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - if np.min(classes) != 0: - yt = yt - np.min(classes) - classes = np.unique(yt) - for c in classes: - D1[int(c), yt == c] = 1 + D1[int(c), ytTemp == c] = 1 # compute transported samples transp_ys = np.dot(D1, transp.T) - return np.argmax(transp_ys,axis=0) + return np.argmax(transp_ys, axis=0) class LinearTransport(BaseTransport): @@ -2163,7 +2167,7 @@ class JCPOTTransport(BaseTransport): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels ys to obtain target labels + """Propagate source labels ys to obtain target labels as in [27] Parameters ---------- @@ -2178,11 +2182,12 @@ class JCPOTTransport(BaseTransport): # 
check the necessary inputs parameters are here if check_params(ys=ys): - yt = np.zeros((len(np.unique(np.concatenate(ys))),self.xt_.shape[0])) + yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) for i in range(len(ys)): - classes = np.unique(ys[i]) + ysTemp = label_normalization(np.copy(ys[i])) + classes = np.unique(ysTemp) n = len(classes) - ns = len(ys[i]) + ns = len(ysTemp) # perform label propagation transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] @@ -2195,16 +2200,13 @@ class JCPOTTransport(BaseTransport): else: D1 = np.zeros((n, ns)) - if np.min(classes) != 0: - ys = ys - np.min(classes) - classes = np.unique(ys) - for c in classes: - D1[int(c), ys == c] = 1 + D1[int(c), ysTemp == c] = 1 + # compute transported samples - yt = yt + np.dot(D1, transp)/len(ys) + yt = yt + np.dot(D1, transp) / len(ys) - return np.argmax(yt,axis=0) + return np.argmax(yt, axis=0) def inverse_transform_labels(self, yt=None): """Propagate source labels ys to obtain target labels @@ -2223,16 +2225,13 @@ class JCPOTTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(yt=yt): transp_ys = [] - classes = np.unique(yt) + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) n = len(classes) - D1 = np.zeros((n, len(yt))) - - if np.min(classes) != 0: - yt = yt - np.min(classes) - classes = np.unique(yt) + D1 = np.zeros((n, len(ytTemp))) for c in classes: - D1[int(c), yt == c] = 1 + D1[int(c), ytTemp == c] = 1 for i in range(len(self.xs_)): @@ -2243,6 +2242,6 @@ class JCPOTTransport(BaseTransport): transp[~ np.isfinite(transp)] = 0 # compute transported labels - transp_ys.append(np.argmax(np.dot(D1, transp.T),axis=0)) + transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0)) return transp_ys diff --git a/ot/utils.py b/ot/utils.py index b71458b..c154f99 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -200,6 +200,28 @@ def dots(*args): return reduce(np.dot, args) +def label_normalization(y, start=0): 
+ """ Transform labels to start at a given value + + Parameters + ---------- + y : array-like, shape (n, ) + The vector of labels to be normalized. + start : int + Desired value for the smallest label in y (default=0) + + Returns + ------- + y : array-like, shape (n1, ) + The input vector of labels normalized according to given start value. + """ + + diff = np.min(np.unique(y)) - start + if diff != 0: + y -= diff + return y + + def fun(f, q_in, q_out): """ Utility function for parmap with no serializing problems """ while True: diff --git a/test/test_da.py b/test/test_da.py index 4eb6de0..d96046d 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -650,6 +650,7 @@ def test_jcpot_transport_class(): transp_ys = otda.inverse_transform_labels(yt) [assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)] + def test_jcpot_barycenter(): """test_jcpot_barycenter """ |