diff options
-rw-r--r-- | ot/da.py | 28 |
1 files changed, 14 insertions, 14 deletions
@@ -908,7 +908,8 @@ class BaseTransport(BaseEstimator): at the class level in their ``__init__`` as explicit keyword arguments (no ``*args`` or ``**kwargs``). - fit method should: + the fit method should: + - estimate a cost matrix and store it in a `cost_` attribute - estimate a coupling matrix and store it in a `coupling_` attribute @@ -933,7 +934,7 @@ class BaseTransport(BaseEstimator): Xs : array-like, shape (n_source_samples, n_features) The training input samples. ys : array-like, shape (n_source_samples,) - The class labels + The training class labels Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) @@ -994,7 +995,7 @@ class BaseTransport(BaseEstimator): Xs : array-like, shape (n_source_samples, n_features) The training input samples. ys : array-like, shape (n_source_samples,) - The class labels + The class labels for training samples Xt : array-like, shape (n_target_samples, n_features) The training input samples. yt : array-like, shape (n_target_samples,) @@ -1018,13 +1019,13 @@ class BaseTransport(BaseEstimator): Parameters ---------- Xs : array-like, shape (n_source_samples, n_features) - The training input samples. + The source input samples. ys : array-like, shape (n_source_samples,) - The class labels + The class labels for source samples Xt : array-like, shape (n_target_samples, n_features) - The training input samples. + The target input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the + The class labels for target. If some target samples are unlabeled, fill the yt's elements with -1. Warning: Note that, due to this convention -1 cannot be used as a @@ -1085,7 +1086,7 @@ class BaseTransport(BaseEstimator): Parameters ---------- ys : array-like, shape (n_source_samples,) - The class labels + The source class labels Returns ------- @@ -1125,18 +1126,18 @@ class BaseTransport(BaseEstimator): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples Xt onto target samples Xs + """Transports target samples Xt onto source samples Xs Parameters ---------- Xs : array-like, shape (n_source_samples, n_features) - The training input samples. + The source input samples. ys : array-like, shape (n_source_samples,) - The class labels + The source class labels Xt : array-like, shape (n_target_samples, n_features) - The training input samples. + The target input samples. yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the + The target class labels. If some target samples are unlabeled, fill the yt's elements with -1. Warning: Note that, due to this convention -1 cannot be used as a @@ -1227,7 +1228,6 @@ class BaseTransport(BaseEstimator): class LinearTransport(BaseTransport): - """ OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two |