summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2020-04-24 11:28:21 +0200
committerRémi Flamary <remi.flamary@gmail.com>2020-04-24 11:28:21 +0200
commit6931f78c7e5b1da4f62a1a85d87409f3f95029c7 (patch)
tree19900dad3007382a2d6189f07a2b9966e84f7eda /ot/da.py
parent8599e720d5f438e2aaf5c635883e64deb026f3ce (diff)
better documentation
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/ot/da.py b/ot/da.py
index 6249f08..4d2bb6c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -908,7 +908,8 @@ class BaseTransport(BaseEstimator):
at the class level in their ``__init__`` as explicit keyword
arguments (no ``*args`` or ``**kwargs``).
- fit method should:
+ the fit method should:
+
- estimate a cost matrix and store it in a `cost_` attribute
- estimate a coupling matrix and store it in a `coupling_`
attribute
@@ -933,7 +934,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The training class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -994,7 +995,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for training samples
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -1018,13 +1019,13 @@ class BaseTransport(BaseEstimator):
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for source samples
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The class labels for target. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1085,7 +1086,7 @@ class BaseTransport(BaseEstimator):
Parameters
----------
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Returns
-------
@@ -1125,18 +1126,18 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto target samples Xs
+ """Transports target samples Xt onto source samples Xs
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The target class labels. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1227,7 +1228,6 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
-
""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two