summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/ot/da.py b/ot/da.py
index 6249f08..4d2bb6c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -908,7 +908,8 @@ class BaseTransport(BaseEstimator):
at the class level in their ``__init__`` as explicit keyword
arguments (no ``*args`` or ``**kwargs``).
- fit method should:
+ the fit method should:
+
- estimate a cost matrix and store it in a `cost_` attribute
- estimate a coupling matrix and store it in a `coupling_`
attribute
@@ -933,7 +934,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The training class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -994,7 +995,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for training samples
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -1018,13 +1019,13 @@ class BaseTransport(BaseEstimator):
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for source samples
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The class labels for target. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1085,7 +1086,7 @@ class BaseTransport(BaseEstimator):
Parameters
----------
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Returns
-------
@@ -1125,18 +1126,18 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto target samples Xs
+ """Transports target samples Xt onto source samples Xs
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The target class labels. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1227,7 +1228,6 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
-
""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two