summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-09 12:40:09 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-09 12:40:09 +0900
commit619bb41a18a542ce768fd4ce3eb9240e9ad6650e (patch)
treecc60cccc304a8d9fcad31d42aab40513b1dce48d /ot
parente58cd780ccf87736265e4e1a39afa3a167325ccc (diff)
parent62dcfbfb78a2be24379cd5cdb4aec70d8c4befaa (diff)
Merge remote-tracking branch 'upstream/master' into ot_dual_variables
Diffstat (limited to 'ot')
-rw-r--r--ot/da.py74
1 files changed, 55 insertions, 19 deletions
diff --git a/ot/da.py b/ot/da.py
index 564c7b7..1d3d0ba 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -966,8 +966,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -989,7 +993,7 @@ class BaseTransport(BaseEstimator):
# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
- classes = np.unique(ys)
+ classes = [c for c in np.unique(ys) if c != -1]
for c in classes:
idx_s = np.where((ys != c) & (ys != -1))
idx_t = np.where(yt == c)
@@ -1023,8 +1027,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1045,8 +1053,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
@@ -1110,8 +1122,12 @@ class BaseTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
@@ -1241,8 +1257,12 @@ class SinkhornTransport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1333,8 +1353,12 @@ class EMDTransport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1434,8 +1458,12 @@ class SinkhornLpl1Transport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1545,8 +1573,12 @@ class SinkhornL1l2Transport(BaseTransport):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------
@@ -1662,8 +1694,12 @@ class MappingTransport(BaseEstimator):
The class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
- yt : array-like, shape (n_labeled_target_samples,)
- The class labels
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
Returns
-------