summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
commit749378a50abd763c87f5cf24a4b2e0dff2a6ec6a (patch)
treee8aeb51b7bcf3b48fead20fa44eee154da5b8d05 /ot/da.py
parent1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (diff)
fix soft labels, remove gammas from jcpot
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py40
1 files changed, 21 insertions, 19 deletions
diff --git a/ot/da.py b/ot/da.py
index 4318c0d..30e5a61 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -956,8 +956,8 @@ class BaseTransport(BaseEstimator):
Returns
-------
- transp_ys : array-like, shape (n_target_samples,)
- Estimated target labels.
+ transp_ys : array-like, shape (n_target_samples, nb_classes)
+ Estimated soft target labels.
References
----------
@@ -985,10 +985,10 @@ class BaseTransport(BaseEstimator):
for c in classes:
D1[int(c), ysTemp == c] = 1
- # compute transported samples
+ # compute propagated labels
transp_ys = np.dot(D1, transp)
- return np.argmax(transp_ys, axis=0)
+ return transp_ys.T
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
@@ -1066,8 +1066,8 @@ class BaseTransport(BaseEstimator):
Returns
-------
- transp_ys : array-like, shape (n_source_samples,)
- Estimated source labels.
+ transp_ys : array-like, shape (n_source_samples, nb_classes)
+ Estimated soft source labels.
"""
# check the necessary inputs parameters are here
@@ -1087,10 +1087,10 @@ class BaseTransport(BaseEstimator):
for c in classes:
D1[int(c), ytTemp == c] = 1
- # compute transported samples
+ # compute propagated samples
transp_ys = np.dot(D1, transp.T)
- return np.argmax(transp_ys, axis=0)
+ return transp_ys.T
class LinearTransport(BaseTransport):
@@ -2083,13 +2083,15 @@ class JCPOTTransport(BaseTransport):
returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e,
metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol,
- verbose=self.verbose, log=self.log)
+ verbose=self.verbose, log=True)
+
+ self.coupling_ = returned_[1]['gamma']
# deal with the value of log
if self.log:
- self.coupling_, self.proportions_, self.log_ = returned_
+ self.proportions_, self.log_ = returned_
else:
- self.coupling_, self.proportions_ = returned_
+ self.proportions_ = returned_
self.log_ = dict()
return self
@@ -2176,8 +2178,8 @@ class JCPOTTransport(BaseTransport):
Returns
-------
- yt : array-like, shape (n_target_samples,)
- Estimated target labels.
+ yt : array-like, shape (n_target_samples, nb_classes)
+ Estimated soft target labels.
"""
# check the necessary inputs parameters are here
@@ -2203,10 +2205,10 @@ class JCPOTTransport(BaseTransport):
for c in classes:
D1[int(c), ysTemp == c] = 1
- # compute transported samples
+ # compute propagated labels
yt = yt + np.dot(D1, transp) / len(ys)
- return np.argmax(yt, axis=0)
+ return yt.T
def inverse_transform_labels(self, yt=None):
"""Propagate source labels ys to obtain target labels
@@ -2218,8 +2220,8 @@ class JCPOTTransport(BaseTransport):
Returns
-------
- transp_ys : list of K array-like objects, shape K x (nk_source_samples,)
- A list of estimated source labels
+ transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes)
+ A list of estimated soft source labels
"""
# check the necessary inputs parameters are here
@@ -2241,7 +2243,7 @@ class JCPOTTransport(BaseTransport):
# set nans to 0
transp[~ np.isfinite(transp)] = 0
- # compute transported labels
- transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
+ # compute propagated labels
+ transp_ys.append(np.dot(D1, transp.T).T)
return transp_ys