summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
commit749378a50abd763c87f5cf24a4b2e0dff2a6ec6a (patch)
treee8aeb51b7bcf3b48fead20fa44eee154da5b8d05 /ot
parent1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (diff)
fix soft labels, remove gammas from jcpot
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py9
-rw-r--r--ot/da.py40
2 files changed, 25 insertions, 24 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index c44c141..543dbaa 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1553,8 +1553,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Returns
-------
- gamma : List of K (nsk x nt) ndarrays
- Optimal transportation matrices for the given parameters for each pair of source and target domains
h : (C,) ndarray
proportion estimation in the target domain
log : dict
@@ -1574,7 +1572,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
# log dictionary
if log:
- log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []}
+ log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []}
K = []
M = []
@@ -1657,9 +1655,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
log['M'] = M
log['D1'] = D1
log['D2'] = D2
- return K, bary, log
+ log['gamma'] = K
+ return bary, log
else:
- return K, bary
+ return bary
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
diff --git a/ot/da.py b/ot/da.py
index 4318c0d..30e5a61 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -956,8 +956,8 @@ class BaseTransport(BaseEstimator):
Returns
-------
- transp_ys : array-like, shape (n_target_samples,)
- Estimated target labels.
+ transp_ys : array-like, shape (n_target_samples, nb_classes)
+ Estimated soft target labels.
References
----------
@@ -985,10 +985,10 @@ class BaseTransport(BaseEstimator):
for c in classes:
D1[int(c), ysTemp == c] = 1
- # compute transported samples
+ # compute propagated labels
transp_ys = np.dot(D1, transp)
- return np.argmax(transp_ys, axis=0)
+ return transp_ys.T
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
@@ -1066,8 +1066,8 @@ class BaseTransport(BaseEstimator):
Returns
-------
- transp_ys : array-like, shape (n_source_samples,)
- Estimated source labels.
+ transp_ys : array-like, shape (n_source_samples, nb_classes)
+ Estimated soft source labels.
"""
# check the necessary inputs parameters are here
@@ -1087,10 +1087,10 @@ class BaseTransport(BaseEstimator):
for c in classes:
D1[int(c), ytTemp == c] = 1
- # compute transported samples
+ # compute propagated samples
transp_ys = np.dot(D1, transp.T)
- return np.argmax(transp_ys, axis=0)
+ return transp_ys.T
class LinearTransport(BaseTransport):
@@ -2083,13 +2083,15 @@ class JCPOTTransport(BaseTransport):
returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e,
metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol,
- verbose=self.verbose, log=self.log)
+ verbose=self.verbose, log=True)
+
+ self.coupling_ = returned_[1]['gamma']
# deal with the value of log
if self.log:
- self.coupling_, self.proportions_, self.log_ = returned_
+ self.proportions_, self.log_ = returned_
else:
- self.coupling_, self.proportions_ = returned_
+ self.proportions_ = returned_
self.log_ = dict()
return self
@@ -2176,8 +2178,8 @@ class JCPOTTransport(BaseTransport):
Returns
-------
- yt : array-like, shape (n_target_samples,)
- Estimated target labels.
+ yt : array-like, shape (n_target_samples, nb_classes)
+ Estimated soft target labels.
"""
# check the necessary inputs parameters are here
@@ -2203,10 +2205,10 @@ class JCPOTTransport(BaseTransport):
for c in classes:
D1[int(c), ysTemp == c] = 1
- # compute transported samples
+ # compute propagated labels
yt = yt + np.dot(D1, transp) / len(ys)
- return np.argmax(yt, axis=0)
+ return yt.T
def inverse_transform_labels(self, yt=None):
"""Propagate source labels ys to obtain target labels
@@ -2218,8 +2220,8 @@ class JCPOTTransport(BaseTransport):
Returns
-------
- transp_ys : list of K array-like objects, shape K x (nk_source_samples,)
- A list of estimated source labels
+ transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes)
+ A list of estimated soft source labels
"""
# check the necessary inputs parameters are here
@@ -2241,7 +2243,7 @@ class JCPOTTransport(BaseTransport):
# set nans to 0
transp[~ np.isfinite(transp)] = 0
- # compute transported labels
- transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0))
+ # compute propagated labels
+ transp_ys.append(np.dot(D1, transp.T).T)
return transp_ys