Diffstat (limited to 'ot')
-rw-r--r--  ot/da.py     969
-rw-r--r--  ot/utils.py  240
2 files changed, 1167 insertions, 42 deletions
diff --git a/ot/da.py b/ot/da.py
index 4f9bce5..78dc150 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,21 +10,27 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
+
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel
+from .utils import check_params, deprecated, BaseEstimator
from .optim import cg
from .optim import gcg
-def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
+def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False):
"""
- Solve the entropic regularization optimal transport problem with nonconvex group lasso regularization
+ Solve the entropic regularization optimal transport problem with nonconvex
+ group lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
+ + \eta \Omega_g(\gamma)
s.t. \gamma 1 = a
@@ -34,11 +40,16 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
where :
- M is the (ns,nt) metric cost matrix
- - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
+ where :math:`\mathcal{I}_c` are the indices of samples from class c
+ in the source domain.
- a and b are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
+ The algorithm used for solving the problem is the generalised conditional
+ gradient as proposed in [5]_ [7]_
Parameters
@@ -78,8 +89,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
References
----------
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
- .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence,
+ vol. PP, no. 99, pp. 1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
See Also
--------
@@ -114,14 +130,18 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
return transp
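
A minimal usage sketch for the function above (not part of this patch; the toy sizes, labels and regularization values are assumptions):

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    ns, nt = 20, 30
    xs = rng.randn(ns, 2)                 # source samples
    xt = rng.randn(nt, 2) + 2.            # shifted target samples
    labels_a = rng.randint(0, 2, ns)      # class labels of the source samples
    a, b = ot.unif(ns), ot.unif(nt)       # uniform weights summing to 1
    M = ot.dist(xs, xt)                   # squared euclidean cost matrix
    # coupling with entropic + nonconvex group lasso regularization
    G = ot.da.sinkhorn_lpl1_mm(a, labels_a, b, M, reg=1., eta=0.1)
    assert G.shape == (ns, nt)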
-def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerItermax=200, stopInnerThr=1e-9, verbose=False, log=False):
+def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False):
"""
- Solve the entropic regularization optimal transport problem with group lasso regularization
+ Solve the entropic regularization optimal transport problem with group
+ lasso regularization
The function solves the following optimization problem:
.. math::
- \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)+ \eta \Omega_g(\gamma)
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma) +
+ \eta \Omega_g(\gamma)
s.t. \gamma 1 = a
@@ -131,11 +151,16 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
where :
- M is the (ns,nt) metric cost matrix
- - :math:`\Omega_e` is the entropic regularization term :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- - :math:`\Omega_g` is the group lasso regulaization term :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2` where :math:`\mathcal{I}_c` are the index of samples from class c in the source domain.
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^2`
+ where :math:`\mathcal{I}_c` are the indices of samples from class
+ c in the source domain.
- a and b are source and target weights (sum to 1)
- The algorithm used for solving the problem is the generalised conditional gradient as proposed in [5]_ [7]_
+ The algorithm used for solving the problem is the generalised conditional
+ gradient as proposed in [5]_ [7]_
Parameters
@@ -175,8 +200,12 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
References
----------
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
- .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence and
+ applications. arXiv preprint arXiv:1510.06567.
See Also
--------
@@ -203,16 +232,22 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, numInnerIte
W[labels_a == lab, i] = temp / n
return W
- return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax, numInnerItermax=numInnerItermax, stopThr=stopInnerThr, verbose=verbose, log=log)
+ return gcg(a, b, M, reg, eta, f, df, G0=None, numItermax=numItermax,
+ numInnerItermax=numInnerItermax, stopThr=stopInnerThr,
+ verbose=verbose, log=log)
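
The l1l2 variant is called the same way; a hedged sketch on assumed toy data, here also requesting the solver log:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    a, b = ot.unif(20), ot.unif(30)
    labels_a = rng.randint(0, 2, 20)
    M = ot.dist(rng.randn(20, 2), rng.randn(30, 2) + 2.)
    # with log=True the gcg solver history is returned alongside the coupling
    G, log = ot.da.sinkhorn_l1l2_gl(a, labels_a, b, M, reg=1., eta=0.1,
                                    log=True)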
-def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs):
+def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
+ verbose2=False, numItermax=100, numInnerItermax=10,
+ stopInnerThr=1e-6, stopThr=1e-5, log=False,
+ **kwargs):
"""Joint OT and linear mapping estimation as proposed in [8]
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L -I\|^2_F
+ \min_{\gamma,L}\quad \|L(X_s) -n_s\gamma X_t\|^2_F +
+ \mu<\gamma,M>_F + \eta \|L -I\|^2_F
s.t. \gamma 1 = a
@@ -221,8 +256,10 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
\gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
- - :math:`L` is a dxd linear operator that approximates the barycentric mapping
+ - M is the (ns,nt) squared Euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a dxd linear operator that approximates the barycentric
+ mapping
- :math:`I` is the identity matrix (neutral linear mapping)
- a and b are uniform source and target weights
@@ -277,7 +314,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
References
----------
- .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
See Also
--------
@@ -384,13 +423,18 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
return G, L
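
A sketch of how the estimated linear mapping can be applied to samples, assuming toy data (with bias=True the mapping acts on the augmented vector [x, 1]):

    import numpy as np
    import ot

    xs = np.random.randn(20, 2)
    xt = np.random.randn(30, 2) + 1.
    G, L = ot.da.joint_OT_mapping_linear(xs, xt, mu=1, eta=1e-3, bias=True)
    # apply the (d+1, d) mapping to the bias-augmented source samples
    xs_mapped = np.hstack((xs, np.ones((xs.shape[0], 1)))).dot(L)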
-def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigma=1, bias=False, verbose=False, verbose2=False, numItermax=100, numInnerItermax=10, stopInnerThr=1e-6, stopThr=1e-5, log=False, **kwargs):
+def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
+ sigma=1, bias=False, verbose=False, verbose2=False,
+ numItermax=100, numInnerItermax=10,
+ stopInnerThr=1e-6, stopThr=1e-5, log=False,
+ **kwargs):
"""Joint OT and nonlinear mapping estimation with kernels as proposed in [8]
The function solves the following optimization problem:
.. math::
- \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
+ \min_{\gamma,L\in\mathcal{H}}\quad \|L(X_s) -
+ n_s\gamma X_t\|^2_F + \mu<\gamma,M>_F + \eta \|L\|^2_\mathcal{H}
s.t. \gamma 1 = a
@@ -399,8 +443,10 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
\gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in Xs and Xt (scaled by ns)
- - :math:`L` is a ns x d linear operator on a kernel matrix that approximates the barycentric mapping
+ - M is the (ns,nt) squared Euclidean cost matrix between samples in
+ Xs and Xt (scaled by ns)
+ - :math:`L` is a ns x d linear operator on a kernel matrix that
+ approximates the barycentric mapping
- a and b are uniform source and target weights
The problem consists in solving jointly an optimal transport matrix
@@ -458,7 +504,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
References
----------
- .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard, "Mapping estimation for discrete optimal transport", Neural Information Processing Systems (NIPS), 2016.
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
See Also
--------
@@ -585,6 +633,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', sigm
return G, L
+@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
+ "removed in 0.5"
+ "\n\tfor standard transport use class EMDTransport instead.")
class OTDA(object):
"""Class for domain adaptation with optimal transport as proposed in [5]
@@ -593,7 +644,9 @@ class OTDA(object):
References
----------
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions on
+ Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
"""
@@ -606,7 +659,8 @@ class OTDA(object):
self.computed = False
def fit(self, xs, xt, ws=None, wt=None, norm=None):
- """ Fit domain adaptation between samples is xs and xt (with optional weights)"""
+ """Fit domain adaptation between samples is xs and xt
+ (with optional weights)"""
self.xs = xs
self.xt = xt
@@ -669,7 +723,9 @@ class OTDA(object):
References
----------
- .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
+ .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
+ Regularized discrete optimal transport. SIAM Journal on Imaging
+ Sciences, 7(3), 1853-1882.
"""
if direction > 0: # >0 then source to target
@@ -706,12 +762,19 @@ class OTDA(object):
self.M = np.log(1 + np.log(1 + self.M))
+@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class SinkhornTransport instead.")
class OTDA_sinkhorn(OTDA):
- """Class for domain adaptation with optimal transport with entropic regularization"""
+ """Class for domain adaptation with optimal transport with entropic
+ regularization
+
+
+ """
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
- """ Fit regularized domain adaptation between samples is xs and xt (with optional weights)"""
+ """Fit regularized domain adaptation between samples is xs and xt
+ (with optional weights)"""
self.xs = xs
self.xt = xt
@@ -729,12 +792,18 @@ class OTDA_sinkhorn(OTDA):
self.computed = True
+@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
class OTDA_lpl1(OTDA):
- """Class for domain adaptation with optimal transport with entropic and group regularization"""
+ """Class for domain adaptation with optimal transport with entropic and
+ group regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
- """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit parameters"""
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
+ **kwargs):
+ """Fit regularized domain adaptation between samples is xs and xt
+ (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
+ parameters"""
self.xs = xs
self.xt = xt
@@ -752,12 +821,18 @@ class OTDA_lpl1(OTDA):
self.computed = True
+@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
class OTDA_l1l2(OTDA):
- """Class for domain adaptation with optimal transport with entropic and group lasso regularization"""
+ """Class for domain adaptation with optimal transport with entropic
+ and group lasso regularization"""
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None, **kwargs):
- """ Fit regularized domain adaptation between samples is xs and xt (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit parameters"""
+ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
+ **kwargs):
+ """Fit regularized domain adaptation between samples is xs and xt
+ (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
+ parameters"""
self.xs = xs
self.xt = xt
@@ -775,9 +850,13 @@ class OTDA_l1l2(OTDA):
self.computed = True
+@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class MappingTransport instead.")
class OTDA_mapping_linear(OTDA):
- """Class for optimal transport with joint linear mapping estimation as in [8]"""
+ """Class for optimal transport with joint linear mapping estimation as in
+ [8]
+ """
def __init__(self):
""" Class initialization"""
@@ -818,11 +897,15 @@ class OTDA_mapping_linear(OTDA):
return None
+@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class MappingTransport instead.")
class OTDA_mapping_kernel(OTDA_mapping_linear):
- """Class for optimal transport with joint nonlinear mapping estimation as in [8]"""
+ """Class for optimal transport with joint nonlinear mapping
+ estimation as in [8]"""
- def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian', sigma=1, **kwargs):
+ def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
+ sigma=1, **kwargs):
""" Fit domain adaptation between samples is xs and xt """
self.xs = xs
self.xt = xt
@@ -843,10 +926,812 @@ class OTDA_mapping_kernel(OTDA_mapping_linear):
if self.computed:
K = kernel(
- x, self.xs, method=self.kernel, sigma=self.sigma, **self.kwargs)
+ x, self.xs, method=self.kernel, sigma=self.sigma,
+ **self.kwargs)
if self.bias:
K = np.hstack((K, np.ones((x.shape[0], 1))))
return K.dot(self.L)
else:
print("Warning, model not fitted yet, returning None")
return None
+
+
+def distribution_estimation_uniform(X):
+ """estimates a uniform distribution from an array of samples X
+
+ Parameters
+ ----------
+ X : array-like, shape (n_samples, n_features)
+ The array of samples
+
+ Returns
+ -------
+ mu : array-like, shape (n_samples,)
+ The uniform distribution estimated from X
+ """
+
+ return unif(X.shape[0])
+
+
+class BaseTransport(BaseEstimator):
+ """Base class for OTDA objects
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+
+ fit method should:
+ - estimate a cost matrix and store it in a `cost_` attribute
+ - estimate a coupling matrix and store it in a `coupling_`
+ attribute
+ - estimate distributions from source and target data and store them in
+ mu_s and mu_t attributes
+ - store Xs and Xt in attributes to be used later on in transform and
+ inverse_transform methods
+
+ transform method should always get as input an Xs parameter
+ inverse_transform method should always get as input an Xt parameter
+ """
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ # pairwise distance
+ self.cost_ = dist(Xs, Xt, metric=self.metric)
+
+ if (ys is not None) and (yt is not None):
+
+ if self.limit_max != np.infty:
+ self.limit_max = self.limit_max * np.max(self.cost_)
+
+ # assumes labeled source samples occupy the first rows
+ # and labeled target samples occupy the first columns
+ classes = np.unique(ys)
+ for c in classes:
+ idx_s = np.where((ys != c) & (ys != -1))
+ idx_t = np.where(yt == c)
+
+ # all the coefficients corresponding to a source sample
+ # and a target sample with different labels
+ # get an infinite cost
+ for j in idx_t[0]:
+ self.cost_[idx_s[0], j] = self.limit_max
+
+ # distribution estimation
+ self.mu_s = self.distribution_estimation(Xs)
+ self.mu_t = self.distribution_estimation(Xt)
+
+ # store arrays of samples
+ self.xs_ = Xs
+ self.xt_ = Xt
+
+ return self
+
+ def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt) and transports source samples Xs onto target
+ ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transported source samples.
+ """
+
+ return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+ batch_size : int, optional (default=128)
+ The batch size for out of sample transform
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transported source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if np.array_equal(self.xs_, Xs):
+
+ # perform standard barycentric mapping
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs = np.dot(transp, self.xt_)
+ else:
+ # perform out of sample mapping
+ indices = np.arange(Xs.shape[0])
+ batch_ind = [
+ indices[i:i + batch_size]
+ for i in range(0, len(indices), batch_size)]
+
+ transp_Xs = []
+ for bi in batch_ind:
+
+ # get the nearest neighbor in the source domain
+ D0 = dist(Xs[bi], self.xs_)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the source samples
+ transp = self.coupling_ / np.sum(
+ self.coupling_, 1)[:, None]
+ transp[~ np.isfinite(transp)] = 0
+ transp_Xs_ = np.dot(transp, self.xt_)
+
+ # define the transported points
+ transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
+
+ transp_Xs.append(transp_Xs_)
+
+ transp_Xs = np.concatenate(transp_Xs, axis=0)
+
+ return transp_Xs
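
The barycentric mapping used above is simply the row-normalised coupling applied to the stored target samples; a standalone numeric sketch with an assumed toy coupling:

    import numpy as np

    coupling = np.array([[0.3, 0.2],
                         [0.0, 0.5]])      # toy (ns=2, nt=2) coupling
    xt = np.array([[0., 0.], [10., 10.]])  # target samples
    transp = coupling / coupling.sum(1, keepdims=True)
    transp[~np.isfinite(transp)] = 0       # guard against empty rows
    transp_Xs = transp.dot(xt)             # -> [[4., 4.], [10., 10.]]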
+
+ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+ batch_size=128):
+ """Transports target samples Xt onto target samples Xs
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+ batch_size : int, optional (default=128)
+ The batch size for out of sample inverse transform
+
+ Returns
+ -------
+ transp_Xt : array-like, shape (n_target_samples, n_features)
+ The transported target samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xt=Xt):
+
+ if np.array_equal(self.xt_, Xt):
+
+ # perform standard barycentric mapping
+ transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
+
+ # set nans to 0
+ transp_[~ np.isfinite(transp_)] = 0
+
+ # compute transported samples
+ transp_Xt = np.dot(transp_, self.xs_)
+ else:
+ # perform out of sample mapping
+ indices = np.arange(Xt.shape[0])
+ batch_ind = [
+ indices[i:i + batch_size]
+ for i in range(0, len(indices), batch_size)]
+
+ transp_Xt = []
+ for bi in batch_ind:
+
+ D0 = dist(Xt[bi], self.xt_)
+ idx = np.argmin(D0, axis=1)
+
+ # transport the target samples
+ transp_ = self.coupling_.T / np.sum(
+ self.coupling_, 0)[:, None]
+ transp_[~ np.isfinite(transp_)] = 0
+ transp_Xt_ = np.dot(transp_, self.xs_)
+
+ # define the transported points
+ transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
+
+ transp_Xt.append(transp_Xt_)
+
+ transp_Xt = np.concatenate(transp_Xt, axis=0)
+
+ return transp_Xt
+
+
+class SinkhornTransport(BaseTransport):
+ """Domain Adapatation OT method based on Sinkhorn Algorithm
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ max_iter : int, float, optional (default=1000)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ tol : float, optional (default=10e-9)
+ The precision required to stop the optimization algorithm.
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. With "ferradans" only the samples
+ used to estimate the coupling are mapped exactly; new samples are
+ transported through their nearest neighbour in the fitted samples.
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ distribution_estimation : callable, optional (defaults to
+ distribution_estimation_uniform)
+ The kind of distribution estimation to employ
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ limit_max : float, optional (default=np.infty)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, an empty dict if parameter log is not True
+
+ References
+ ----------
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+ .. [2] M. Cuturi, Sinkhorn Distances: Lightspeed Computation of Optimal
+ Transport, Advances in Neural Information Processing Systems (NIPS)
+ 26, 2013
+ """
+
+ def __init__(self, reg_e=1., max_iter=1000,
+ tol=10e-9, verbose=False, log=False,
+ metric="sqeuclidean",
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=np.infty):
+
+ self.reg_e = reg_e
+ self.max_iter = max_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.limit_max = limit_max
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
+
+ # coupling estimation
+ returned_ = sinkhorn(
+ a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
+ numItermax=self.max_iter, stopThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
+
+
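A hedged end-to-end sketch of the estimator above (toy data; parameter values are illustrative only):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs = rng.randn(50, 2)
    Xt = rng.randn(60, 2) + 3.
    st = ot.da.SinkhornTransport(reg_e=1e-1)
    st.fit(Xs=Xs, Xt=Xt)             # estimates cost_, coupling_, mu_s, mu_t
    Xs_mapped = st.transform(Xs=Xs)  # barycentric mapping onto Xt
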
+class EMDTransport(BaseTransport):
+ """Domain Adapatation OT method based on Earth Mover's Distance
+
+ Parameters
+ ----------
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. With "ferradans" only the samples
+ used to estimate the coupling are mapped exactly; new samples are
+ transported through their nearest neighbour in the fitted samples.
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ distribution_estimation : callable, optional (defaults to
+ distribution_estimation_uniform)
+ The kind of distribution estimation to employ
+ limit_max : float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost of
+ limit_max times the maximum value of the cost matrix
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+
+ References
+ ----------
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
+ on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1
+ """
+
+ def __init__(self, metric="sqeuclidean",
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.metric = metric
+ self.limit_max = limit_max
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ super(EMDTransport, self).fit(Xs, ys, Xt, yt)
+
+ # coupling estimation
+ self.coupling_ = emd(
+ a=self.mu_s, b=self.mu_t, M=self.cost_,
+ )
+
+ return self
+
+
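EMDTransport follows the same fit/transform pattern but computes an exact, unregularized coupling; a sketch under the same toy-data assumptions:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs, Xt = rng.randn(50, 2), rng.randn(60, 2) + 3.
    et = ot.da.EMDTransport()
    et.fit(Xs=Xs, Xt=Xt)             # exact EMD coupling, no entropic term
    Xs_mapped = et.transform(Xs=Xs)
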
+class SinkhornLpl1Transport(BaseTransport):
+ """Domain Adapatation OT method based on sinkhorn algorithm +
+ LpL1 class regularization.
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_cl : float, optional (default=0.1)
+ Class regularization parameter
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. With "ferradans" only the samples
+ used to estimate the coupling are mapped exactly; new samples are
+ transported through their nearest neighbour in the fitted samples.
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ distribution_estimation : callable, optional (defaults to
+ distribution_estimation_uniform)
+ The kind of distribution estimation to employ
+ max_iter : int, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ max_inner_iter : int, optional (default=200)
+ The number of iterations in the inner loop
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ limit_max : float, optional (default=np.infty)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit an infinite cost
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+
+ References
+ ----------
+
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence,
+ vol. PP, no. 99, pp. 1-1
+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ """
+
+ def __init__(self, reg_e=1., reg_cl=0.1,
+ max_iter=10, max_inner_iter=200,
+ tol=10e-9, verbose=False,
+ metric="sqeuclidean",
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=np.infty):
+
+ self.reg_e = reg_e
+ self.reg_cl = reg_cl
+ self.max_iter = max_iter
+ self.max_inner_iter = max_inner_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.metric = metric
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
+
+ super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
+
+ self.coupling_ = sinkhorn_lpl1_mm(
+ a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
+ verbose=self.verbose)
+
+ return self
+
+
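The class-regularized variant additionally requires the source labels ys at fit time (enforced by check_params); a sketch with assumed toy labels:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs, Xt = rng.randn(50, 2), rng.randn(60, 2) + 3.
    ys = rng.randint(0, 2, 50)            # source class labels
    lpl1 = ot.da.SinkhornLpl1Transport(reg_e=1., reg_cl=0.1)
    lpl1.fit(Xs=Xs, ys=ys, Xt=Xt)
    Xs_mapped = lpl1.transform(Xs=Xs)
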
+class SinkhornL1l2Transport(BaseTransport):
+ """Domain Adapatation OT method based on sinkhorn algorithm +
+ l1l2 class regularization.
+
+ Parameters
+ ----------
+ reg_e : float, optional (default=1)
+ Entropic regularization parameter
+ reg_cl : float, optional (default=0.1)
+ Class regularization parameter
+ out_of_sample_map : string, optional (default="ferradans")
+ The kind of out of sample mapping to apply to transport samples
+ from a domain into another one. With "ferradans" only the samples
+ used to estimate the coupling are mapped exactly; new samples are
+ transported through their nearest neighbour in the fitted samples.
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ distribution_estimation : callable, optional (defaults to
+ distribution_estimation_uniform)
+ The kind of distribution estimation to employ
+ max_iter : int, optional (default=10)
+ The maximum number of iterations before stopping the optimization
+ algorithm if it has not converged
+ max_inner_iter : int, optional (default=200)
+ The number of iterations in the inner loop
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
+ log : bool, optional (default=False)
+ Controls the logs of the optimization algorithm
+ limit_max : float, optional (default=10)
+ Controls the semi supervised mode. Transport between labeled source
+ and target samples of different classes will exhibit a cost of
+ limit_max times the maximum value of the cost matrix
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ log_ : dictionary
+ The dictionary of log, an empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence,
+ vol. PP, no. 99, pp. 1-1
+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ """
+
+ def __init__(self, reg_e=1., reg_cl=0.1,
+ max_iter=10, max_inner_iter=200,
+ tol=10e-9, verbose=False, log=False,
+ metric="sqeuclidean",
+ distribution_estimation=distribution_estimation_uniform,
+ out_of_sample_map='ferradans', limit_max=10):
+
+ self.reg_e = reg_e
+ self.reg_cl = reg_cl
+ self.max_iter = max_iter
+ self.max_inner_iter = max_inner_iter
+ self.tol = tol
+ self.verbose = verbose
+ self.log = log
+ self.metric = metric
+ self.distribution_estimation = distribution_estimation
+ self.out_of_sample_map = out_of_sample_map
+ self.limit_max = limit_max
+
+ def fit(self, Xs, ys=None, Xt=None, yt=None):
+ """Build a coupling matrix from source and target sets of samples
+ (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ self : object
+ Returns self.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt, ys=ys):
+
+ super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)
+
+ returned_ = sinkhorn_l1l2_gl(
+ a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.cost_,
+ reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
+ verbose=self.verbose, log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.log_ = returned_
+ else:
+ self.coupling_ = returned_
+ self.log_ = dict()
+
+ return self
+
+
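The l1l2 counterpart exposes the same interface plus a log_ attribute; a sketch (toy data assumed):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs, Xt = rng.randn(50, 2), rng.randn(60, 2) + 3.
    ys = rng.randint(0, 2, 50)
    l1l2 = ot.da.SinkhornL1l2Transport(reg_e=1., reg_cl=0.1, log=True)
    l1l2.fit(Xs=Xs, ys=ys, Xt=Xt)         # l1l2.log_ holds the solver history
    Xs_mapped = l1l2.transform(Xs=Xs)
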
+class MappingTransport(BaseEstimator):
+ """MappingTransport: DA methods that aims at jointly estimating a optimal
+ transport coupling and the associated mapping
+
+ Parameters
+ ----------
+ mu : float, optional (default=1)
+ Weight for the linear OT loss (>0)
+ eta : float, optional (default=0.001)
+ Regularization term for the linear mapping L (>0)
+ bias : bool, optional (default=False)
+ Estimate linear mapping with constant bias
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
+ kernel : string, optional (default="linear")
+ The kernel to use, either "linear" or "gaussian"
+ sigma : float, optional (default=1)
+ The gaussian kernel parameter
+ max_iter : int, optional (default=100)
+ Max number of BCD iterations
+ tol : float, optional (default=1e-5)
+ Stop threshold on relative loss decrease (>0)
+ max_inner_iter : int, optional (default=10)
+ Max number of iterations (inner CG solver)
+ inner_tol : float, optional (default=1e-6)
+ Stop threshold on error (inner CG solver) (>0)
+ verbose : bool, optional (default=False)
+ Print information along iterations
+ log : bool, optional (default=False)
+ record log if True
+
+ Attributes
+ ----------
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
+ The optimal coupling
+ mapping_ : array-like
+ The associated mapping, of shape (n_features (+ 1 if bias),
+ n_features) for kernel == "linear", or of shape
+ (n_source_samples (+ 1 if bias), n_features) for
+ kernel == "gaussian"
+ log_ : dictionary
+ The dictionary of log, an empty dict if parameter log is not True
+
+ References
+ ----------
+
+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+ "Mapping estimation for discrete optimal transport",
+ Neural Information Processing Systems (NIPS), 2016.
+
+ """
+
+ def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
+ kernel="linear", sigma=1, max_iter=100, tol=1e-5,
+ max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
+ verbose2=False):
+
+ self.metric = metric
+ self.mu = mu
+ self.eta = eta
+ self.bias = bias
+ self.kernel = kernel
+ self.sigma = sigma
+ self.max_iter = max_iter
+ self.tol = tol
+ self.max_inner_iter = max_inner_iter
+ self.inner_tol = inner_tol
+ self.log = log
+ self.verbose = verbose
+ self.verbose2 = verbose2
+
+ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+ """Builds an optimal coupling and estimates the associated mapping
+ from source and target sets of samples (Xs, ys) and (Xt, yt)
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_labeled_target_samples,)
+ The class labels
+
+ Returns
+ -------
+ self : object
+ Returns self
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs, Xt=Xt):
+
+ self.xs_ = Xs
+ self.xt_ = Xt
+
+ if self.kernel == "linear":
+ returned_ = joint_OT_mapping_linear(
+ Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+ verbose=self.verbose, verbose2=self.verbose2,
+ numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter, stopThr=self.tol,
+ stopInnerThr=self.inner_tol, log=self.log)
+
+ elif self.kernel == "gaussian":
+ returned_ = joint_OT_mapping_kernel(
+ Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+ sigma=self.sigma, verbose=self.verbose,
+ verbose2=self.verbose2, numItermax=self.max_iter,
+ numInnerItermax=self.max_inner_iter,
+ stopInnerThr=self.inner_tol, stopThr=self.tol,
+ log=self.log)
+
+ # deal with the value of log
+ if self.log:
+ self.coupling_, self.mapping_, self.log_ = returned_
+ else:
+ self.coupling_, self.mapping_ = returned_
+ self.log_ = dict()
+
+ return self
+
+ def transform(self, Xs):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+
+ Returns
+ -------
+ transp_Xs : array-like, shape (n_source_samples, n_features)
+ The transported source samples.
+ """
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if np.array_equal(self.xs_, Xs):
+ # perform standard barycentric mapping
+ transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs = np.dot(transp, self.xt_)
+ else:
+ if self.kernel == "gaussian":
+ K = kernel(Xs, self.xs_, method=self.kernel,
+ sigma=self.sigma)
+ elif self.kernel == "linear":
+ K = Xs
+ if self.bias:
+ K = np.hstack((K, np.ones((Xs.shape[0], 1))))
+ transp_Xs = K.dot(self.mapping_)
+
+ return transp_Xs
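
Unlike the barycentric estimators above, MappingTransport can map samples unseen at fit time through the learned mapping; a sketch assuming a gaussian kernel:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    Xs, Xt = rng.randn(50, 2), rng.randn(60, 2) + 3.
    mt = ot.da.MappingTransport(kernel="gaussian", sigma=1., bias=True)
    mt.fit(Xs=Xs, Xt=Xt)
    Xs_new = rng.randn(5, 2)              # samples not seen during fit
    Xs_new_mapped = mt.transform(Xs_new)  # goes through the kernel mapping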
diff --git a/ot/utils.py b/ot/utils.py
index 2b2f8b3..01f2a67 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -13,6 +13,9 @@ import time
import numpy as np
from scipy.spatial.distance import cdist
+import sys
+import warnings
+
__time_tic_toc = time.time()
@@ -163,3 +166,240 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
[p.join() for p in proc]
return [x for i, x in sorted(res)]
+
+
+def check_params(**kwargs):
+ """check_params: check whether some parameters are missing
+ """
+
+ missing_params = []
+ check = True
+
+ for param in kwargs:
+ if kwargs[param] is None:
+ missing_params.append(param)
+
+ if len(missing_params) > 0:
+ print("POT - Warning: following necessary parameters are missing")
+ for p in missing_params:
+ print("\n", p)
+
+ check = False
+
+ return check
+
+
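check_params is used as a guard clause by the transport estimators above; a quick sketch of its behaviour:

    import numpy as np
    from ot.utils import check_params

    # True when every required parameter is provided
    assert check_params(Xs=np.zeros((3, 2)), Xt=np.zeros((4, 2)))
    # False (with a printed warning) when any of them is None
    assert not check_params(Xs=np.zeros((3, 2)), Xt=None)
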
+class deprecated(object):
+ """Decorator to mark a function or class as deprecated.
+
+ deprecated class from scikit-learn package
+ https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
+ Issue a warning when the function is called/the class is instantiated and
+ adds a warning to the docstring.
+ The optional extra argument will be appended to the deprecation message
+ and the docstring. Note: to use this with the default value for extra,
+ put in an empty set of parentheses:
+ >>> from ot.utils import deprecated
+ >>> @deprecated()
+ ... def some_function(): pass
+
+ Parameters
+ ----------
+ extra : string
+ to be added to the deprecation messages
+ """
+
+ # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
+ # but with many changes.
+
+ def __init__(self, extra=''):
+ self.extra = extra
+
+ def __call__(self, obj):
+ """Call method
+ Parameters
+ ----------
+ obj : object
+ """
+ if isinstance(obj, type):
+ return self._decorate_class(obj)
+ else:
+ return self._decorate_fun(obj)
+
+ def _decorate_class(self, cls):
+ msg = "Class %s is deprecated" % cls.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ # FIXME: we should probably reset __new__ for full generality
+ init = cls.__init__
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return init(*args, **kwargs)
+
+ cls.__init__ = wrapped
+
+ wrapped.__name__ = '__init__'
+ wrapped.__doc__ = self._update_doc(init.__doc__)
+ wrapped.deprecated_original = init
+
+ return cls
+
+ def _decorate_fun(self, fun):
+ """Decorate function fun"""
+
+ msg = "Function %s is deprecated" % fun.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return fun(*args, **kwargs)
+
+ wrapped.__name__ = fun.__name__
+ wrapped.__dict__ = fun.__dict__
+ wrapped.__doc__ = self._update_doc(fun.__doc__)
+
+ return wrapped
+
+ def _update_doc(self, olddoc):
+ newdoc = "DEPRECATED"
+ if self.extra:
+ newdoc = "%s: %s" % (newdoc, self.extra)
+ if olddoc:
+ newdoc = "%s\n\n%s" % (newdoc, olddoc)
+ return newdoc
+
+
+def _is_deprecated(func):
+ """Helper to check if func is wraped by our deprecated decorator"""
+ if sys.version_info < (3, 5):
+ raise NotImplementedError("This is only available for python3.5 "
+ "or above")
+ closures = getattr(func, '__closure__', [])
+ if closures is None:
+ closures = []
+ is_deprecated = ('deprecated' in ''.join([c.cell_contents
+ for c in closures
+ if isinstance(c.cell_contents, str)]))
+ return is_deprecated
+
+
+class BaseEstimator(object):
+ """Base class for most objects in POT
+ adapted from sklearn BaseEstimator class
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+ """
+
+ @classmethod
+ def _get_param_names(cls):
+ """Get parameter names for the estimator"""
+ try:
+ from inspect import signature
+ except ImportError:
+ from .externals.funcsigs import signature
+ # fetch the constructor or the original constructor before
+ # deprecation wrapping if any
+ init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
+ if init is object.__init__:
+ # No explicit constructor to introspect
+ return []
+
+ # introspect the constructor arguments to find the model parameters
+ # to represent
+ init_signature = signature(init)
+ # Consider the constructor parameters excluding 'self'
+ parameters = [p for p in init_signature.parameters.values()
+ if p.name != 'self' and p.kind != p.VAR_KEYWORD]
+ for p in parameters:
+ if p.kind == p.VAR_POSITIONAL:
+ raise RuntimeError("POT estimators should always "
+ "specify their parameters in the signature"
+ " of their __init__ (no varargs)."
+ " %s with constructor %s doesn't "
+ " follow this convention."
+ % (cls, init_signature))
+ # Extract and sort argument names excluding 'self'
+ return sorted([p.name for p in parameters])
+
+ def get_params(self, deep=True):
+ """Get parameters for this estimator.
+
+ Parameters
+ ----------
+ deep : boolean, optional
+ If True, will return the parameters for this estimator and
+ contained subobjects that are estimators.
+
+ Returns
+ -------
+ params : mapping of string to any
+ Parameter names mapped to their values.
+ """
+ out = dict()
+ for key in self._get_param_names():
+ # We need deprecation warnings to always be on in order to
+ # catch deprecated param values.
+ # This is set in utils/__init__.py but it gets overwritten
+ # when running under python3 somehow.
+ warnings.simplefilter("always", DeprecationWarning)
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ value = getattr(self, key, None)
+ if len(w) and w[0].category == DeprecationWarning:
+ # if the parameter is deprecated, don't show it
+ continue
+ finally:
+ warnings.filters.pop(0)
+
+ # XXX: should we rather test if instance of estimator?
+ if deep and hasattr(value, 'get_params'):
+ deep_items = value.get_params().items()
+ out.update((key + '__' + k, val) for k, val in deep_items)
+ out[key] = value
+ return out
+
+ def set_params(self, **params):
+ """Set the parameters of this estimator.
+
+ The method works on simple estimators as well as on nested objects
+ (such as pipelines). The latter have parameters of the form
+ ``<component>__<parameter>`` so that it's possible to update each
+ component of a nested object.
+
+ Returns
+ -------
+ self
+ """
+ if not params:
+ # Simple optimisation to gain speed (inspect is slow)
+ return self
+ valid_params = self.get_params(deep=True)
+ # for key, value in iteritems(params):
+ for key, value in params.items():
+ split = key.split('__', 1)
+ if len(split) > 1:
+ # nested objects case
+ name, sub_name = split
+ if name not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (name, self))
+ sub_object = valid_params[name]
+ sub_object.set_params(**{sub_name: value})
+ else:
+ # simple objects case
+ if key not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (key, self.__class__.__name__))
+ setattr(self, key, value)
+ return self
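
get_params and set_params give the POT estimators a scikit-learn-style parameter interface; a sketch of the round trip using the SinkhornTransport class defined in ot/da.py above:

    import ot

    st = ot.da.SinkhornTransport(reg_e=1.)
    params = st.get_params()          # dict of constructor parameters
    params['reg_e'] = 1e-2
    st.set_params(**params)           # updates in place and returns self
    assert st.get_params()['reg_e'] == 1e-2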