diff options
Diffstat (limited to 'ot/da.py')
-rw-r--r-- | ot/da.py | 40 |
1 files changed, 21 insertions, 19 deletions
@@ -956,8 +956,8 @@ class BaseTransport(BaseEstimator): Returns ------- - transp_ys : array-like, shape (n_target_samples,) - Estimated target labels. + transp_ys : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. References ---------- @@ -985,10 +985,10 @@ class BaseTransport(BaseEstimator): for c in classes: D1[int(c), ysTemp == c] = 1 - # compute transported samples + # compute propagated labels transp_ys = np.dot(D1, transp) - return np.argmax(transp_ys, axis=0) + return transp_ys.T def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): @@ -1066,8 +1066,8 @@ class BaseTransport(BaseEstimator): Returns ------- - transp_ys : array-like, shape (n_source_samples,) - Estimated source labels. + transp_ys : array-like, shape (n_source_samples, nb_classes) + Estimated soft source labels. """ # check the necessary inputs parameters are here @@ -1087,10 +1087,10 @@ class BaseTransport(BaseEstimator): for c in classes: D1[int(c), ytTemp == c] = 1 - # compute transported samples + # compute propagated samples transp_ys = np.dot(D1, transp.T) - return np.argmax(transp_ys, axis=0) + return transp_ys.T class LinearTransport(BaseTransport): @@ -2083,13 +2083,15 @@ class JCPOTTransport(BaseTransport): returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, - verbose=self.verbose, log=self.log) + verbose=self.verbose, log=True) + + self.coupling_ = returned_[1]['gamma'] # deal with the value of log if self.log: - self.coupling_, self.proportions_, self.log_ = returned_ + self.proportions_, self.log_ = returned_ else: - self.coupling_, self.proportions_ = returned_ + self.proportions_ = returned_ self.log_ = dict() return self @@ -2176,8 +2178,8 @@ class JCPOTTransport(BaseTransport): Returns ------- - yt : array-like, shape (n_target_samples,) - Estimated target labels. + yt : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. """ # check the necessary inputs parameters are here @@ -2203,10 +2205,10 @@ class JCPOTTransport(BaseTransport): for c in classes: D1[int(c), ysTemp == c] = 1 - # compute transported samples + # compute propagated labels yt = yt + np.dot(D1, transp) / len(ys) - return np.argmax(yt, axis=0) + return yt.T def inverse_transform_labels(self, yt=None): """Propagate source labels ys to obtain target labels @@ -2218,8 +2220,8 @@ class JCPOTTransport(BaseTransport): Returns ------- - transp_ys : list of K array-like objects, shape K x (nk_source_samples,) - A list of estimated source labels + transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) + A list of estimated soft source labels """ # check the necessary inputs parameters are here @@ -2241,7 +2243,7 @@ class JCPOTTransport(BaseTransport): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - # compute transported labels - transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0)) + # compute propagated labels + transp_ys.append(np.dot(D1, transp.T).T) return transp_ys |