summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py54
1 files changed, 37 insertions, 17 deletions
diff --git a/ot/da.py b/ot/da.py
index 6249f08..b881a8b 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -908,7 +908,8 @@ class BaseTransport(BaseEstimator):
at the class level in their ``__init__`` as explicit keyword
arguments (no ``*args`` or ``**kwargs``).
- fit method should:
+ the fit method should:
+
- estimate a cost matrix and store it in a `cost_` attribute
- estimate a coupling matrix and store it in a `coupling_`
attribute
@@ -933,7 +934,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The training class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -994,7 +995,7 @@ class BaseTransport(BaseEstimator):
Xs : array-like, shape (n_source_samples, n_features)
The training input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for training samples
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -1018,13 +1019,13 @@ class BaseTransport(BaseEstimator):
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The class labels for source samples
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The class labels for target. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1085,7 +1086,7 @@ class BaseTransport(BaseEstimator):
Parameters
----------
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Returns
-------
@@ -1125,18 +1126,18 @@ class BaseTransport(BaseEstimator):
def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
batch_size=128):
- """Transports target samples Xt onto target samples Xs
+ """Transports target samples Xt onto source samples Xs
Parameters
----------
Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
+ The source input samples.
ys : array-like, shape (n_source_samples,)
- The class labels
+ The source class labels
Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
+ The target input samples.
yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
+ The target class labels. If some target samples are unlabeled, fill the
yt's elements with -1.
Warning: Note that, due to this convention -1 cannot be used as a
@@ -1227,7 +1228,6 @@ class BaseTransport(BaseEstimator):
class LinearTransport(BaseTransport):
-
""" OT linear operator between empirical distributions
The function estimates the optimal linear operator that aligns the two
@@ -1438,6 +1438,9 @@ class SinkhornTransport(BaseTransport):
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems (NIPS)
26, 2013
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., max_iter=1000,
@@ -1536,6 +1539,9 @@ class EMDTransport(BaseTransport):
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
"Optimal Transport for Domain Adaptation," in IEEE Transactions
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, metric="sqeuclidean", norm=None, log=False,
@@ -1643,7 +1649,9 @@ class SinkhornLpl1Transport(BaseTransport):
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
-
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., reg_cl=0.1,
@@ -1763,6 +1771,9 @@ class EMDLaplaceTransport(BaseTransport):
.. [2] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean",
@@ -1882,7 +1893,9 @@ class SinkhornL1l2Transport(BaseTransport):
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
Generalized conditional gradient: analysis of convergence
and applications. arXiv preprint arXiv:1510.06567.
-
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., reg_cl=0.1,
@@ -2174,7 +2187,9 @@ class UnbalancedSinkhornTransport(BaseTransport):
.. [1] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint
arXiv:1607.05816.
-
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
def __init__(self, reg_e=1., reg_m=0.1, method='sinkhorn',
@@ -2287,6 +2302,11 @@ class JCPOTTransport(BaseTransport):
International Conference on Artificial Intelligence and Statistics (AISTATS),
vol. 89, p.849-858, 2019.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
+
+
"""
def __init__(self, reg_e=.1, max_iter=10,