diff options
author | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-15 11:12:23 +0200 |
---|---|---|
committer | ievred <ievgen.redko@univ-st-etienne.fr> | 2020-04-15 11:12:23 +0200 |
commit | 749378a50abd763c87f5cf24a4b2e0dff2a6ec6a (patch) | |
tree | e8aeb51b7bcf3b48fead20fa44eee154da5b8d05 /ot | |
parent | 1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (diff) |
fix soft labels, remove gammas from jcpot
Diffstat (limited to 'ot')
-rw-r--r-- | ot/bregman.py | 9 | ||||
-rw-r--r-- | ot/da.py | 40 |
2 files changed, 25 insertions, 24 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index c44c141..543dbaa 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1553,8 +1553,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Returns ------- - gamma : List of K (nsk x nt) ndarrays - Optimal transportation matrices for the given parameters for each pair of source and target domains h : (C,) ndarray proportion estimation in the target domain log : dict @@ -1574,7 +1572,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # log dictionary if log: - log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []} + log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []} K = [] M = [] @@ -1657,9 +1655,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, log['M'] = M log['D1'] = D1 log['D2'] = D2 - return K, bary, log + log['gamma'] = K + return bary, log else: - return K, bary + return bary def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', @@ -956,8 +956,8 @@ class BaseTransport(BaseEstimator): Returns ------- - transp_ys : array-like, shape (n_target_samples,) - Estimated target labels. + transp_ys : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. References ---------- @@ -985,10 +985,10 @@ class BaseTransport(BaseEstimator): for c in classes: D1[int(c), ysTemp == c] = 1 - # compute transported samples + # compute propagated labels transp_ys = np.dot(D1, transp) - return np.argmax(transp_ys, axis=0) + return transp_ys.T def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): @@ -1066,8 +1066,8 @@ class BaseTransport(BaseEstimator): Returns ------- - transp_ys : array-like, shape (n_source_samples,) - Estimated source labels. + transp_ys : array-like, shape (n_source_samples, nb_classes) + Estimated soft source labels. """ # check the necessary inputs parameters are here @@ -1087,10 +1087,10 @@ class BaseTransport(BaseEstimator): for c in classes: D1[int(c), ytTemp == c] = 1 - # compute transported samples + # compute propagated samples transp_ys = np.dot(D1, transp.T) - return np.argmax(transp_ys, axis=0) + return transp_ys.T class LinearTransport(BaseTransport): @@ -2083,13 +2083,15 @@ class JCPOTTransport(BaseTransport): returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, - verbose=self.verbose, log=self.log) + verbose=self.verbose, log=True) + + self.coupling_ = returned_[1]['gamma'] # deal with the value of log if self.log: - self.coupling_, self.proportions_, self.log_ = returned_ + self.proportions_, self.log_ = returned_ else: - self.coupling_, self.proportions_ = returned_ + self.proportions_ = returned_ self.log_ = dict() return self @@ -2176,8 +2178,8 @@ class JCPOTTransport(BaseTransport): Returns ------- - yt : array-like, shape (n_target_samples,) - Estimated target labels. + yt : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. """ # check the necessary inputs parameters are here @@ -2203,10 +2205,10 @@ class JCPOTTransport(BaseTransport): for c in classes: D1[int(c), ysTemp == c] = 1 - # compute transported samples + # compute propagated labels yt = yt + np.dot(D1, transp) / len(ys) - return np.argmax(yt, axis=0) + return yt.T def inverse_transform_labels(self, yt=None): """Propagate source labels ys to obtain target labels @@ -2218,8 +2220,8 @@ class JCPOTTransport(BaseTransport): Returns ------- - transp_ys : list of K array-like objects, shape K x (nk_source_samples,) - A list of estimated source labels + transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) + A list of estimated soft source labels """ # check the necessary inputs parameters are here @@ -2241,7 +2243,7 @@ class JCPOTTransport(BaseTransport): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - # compute transported labels - transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0)) + # compute propagated labels + transp_ys.append(np.dot(D1, transp.T).T) return transp_ys |