summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-03-31 17:12:28 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-03-31 17:12:28 +0200
commit6aa0f1f4e275098948d4b312530119e5d95b8884 (patch)
tree58c78624e5e687f2453bba4df77331befa7da999 /ot/da.py
parent171b962cea369aee2513884a1fb3dca8920b77cd (diff)
v1 jcpot example test
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py263
1 files changed, 157 insertions, 106 deletions
diff --git a/ot/da.py b/ot/da.py
index fd5da4b..a3da8c1 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -748,79 +748,58 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
stopThr=1e-6, verbose=False, log=False, **kwargs):
- """Joint OT and proportion estimation as proposed in [27]
+ r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
The function solves the following optimization problem:
.. math::
- \mathbf{h} = \argmin_{\mathbf{h} \in \Delta_C}\quad \sum_{k=1}^K \lambda_k
- W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T \mathbf{\delta}_{\mathbf{X}^{(k)}}, \mu\right)
+ \mathbf{h} = \arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
+ s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
- s.t. \gamma^T_k \mathbf{1}_n = \mathbf{1}_n/n
-
- \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
-
- \gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in
- Xs and Xt (scaled by ns)
- - :math:`L` is a ns x d linear operator on a kernel matrix that
- approximates the barycentric mapping
- - a and b are uniform source and target weights
-
- The problem consist in solving jointly an optimal transport matrix
- :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`.
+ - :math:`\lambda_k` is the weight of k-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
+ - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(C, n_k)`
- One can also estimate a mapping with constant bias (see supplementary
- material of [8]) using the bias optional argument.
-
- The algorithm used for solving the problem is the block coordinate
- descent that alternates between updates of G (using conditional gradient)
- and the update of L using a classical kernel least square solver.
+ The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+ The algorithm used for solving the problem is the Iterative Bregman projections algorithm
+ with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution.
Parameters
----------
- xs : np.ndarray (ns,d)
- samples in the source domain
- xt : np.ndarray (nt,d)
+ Xs : list of K np.ndarray(nsk,d)
+ features of all source domains' samples
+ Ys : list of K np.ndarray(nsk,)
+ labels of all source domains' samples
+ Xt : np.ndarray (nt,d)
samples in the target domain
- mu : float,optional
- Weight for the linear OT loss (>0)
- eta : float, optional
- Regularization term for the linear mapping L (>0)
- kerneltype : str,optional
- kernel used by calling function ot.utils.kernel (gaussian by default)
- sigma : float, optional
- Gaussian kernel bandwidth.
- bias : bool,optional
- Estimate linear mapping with constant bias
- verbose : bool, optional
- Print information along iterations
- verbose2 : bool, optional
- Print information along iterations
+ reg : float
+ Regularization term > 0
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
numItermax : int, optional
- Max number of BCD iterations
- numInnerItermax : int, optional
- Max number of iterations (inner CG solver)
- stopInnerThr : float, optional
- Stop threshold on error (inner CG solver) (>0)
+ Max number of iterations
stopThr : float, optional
- Stop threshold on relative loss decrease (>0)
+ Stop threshold on relative change in the barycenter (>0)
log : bool, optional
record log if True
-
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- L : (ns x d) ndarray
- Nonlinear mapping matrix (ns+1 x d if bias)
+ gamma : List of K (nsk x nt) ndarrays
+ Optimal transportation matrices for the given parameters for each pair of source and target domains
+ h : (C,) ndarray
+ proportion estimation in the target domain
log : dict
log dictionary return only if log==True in parameters
@@ -828,62 +807,59 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
References
----------
- .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
- "Mapping estimation for discrete optimal transport",
- Neural Information Processing Systems (NIPS), 2016.
-
- See Also
- --------
- ot.lp.emd : Unregularized OT
- ot.optim.cg : General regularized OT
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
- """
+ '''
nbclasses = len(np.unique(Ys[0]))
nbdomains = len(Xs)
- # we then build, for each source domain, specific information
+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
all_domains = []
+
+ # log dictionary
+ if log:
+ log = {'niter': 0, 'err': [], 'all_domains': []}
+
for d in range(nbdomains):
dom = {}
- # get number of elements for this domain
- nb_elem = Xs[d].shape[0]
- dom['nbelem'] = nb_elem
- classes = np.unique(Ys[d])
+ nsk = Xs[d].shape[0] # get number of elements for this domain
+ dom['nbelem'] = nsk
+ classes = np.unique(Ys[d]) # get number of classes for this domain
+ # format classes to start from 0 for convenience
if np.min(classes) != 0:
Ys[d] = Ys[d] - np.min(classes)
classes = np.unique(Ys[d])
- # build the corresponding D matrix
- D1 = np.zeros((nbclasses, nb_elem))
- D2 = np.zeros((nbclasses, nb_elem))
- classes_d = np.zeros(nbclasses)
-
- classes_d[np.unique(Ys[d]).astype(int)] = 1
- dom['classes'] = classes_d
+ # build the corresponding D_1 and D_2 matrices
+ D1 = np.zeros((nbclasses, nsk))
+ D2 = np.zeros((nbclasses, nsk))
for c in classes:
nbelemperclass = np.sum(Ys[d] == c)
if nbelemperclass != 0:
D1[int(c), Ys[d] == c] = 1.
- D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) # *nbclasses_d)
+ D2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
dom['D1'] = D1
dom['D2'] = D2
- # build the distance matrix
+ # build the cost matrix and the Gibbs kernel
M = dist(Xs[d], Xt, metric=metric)
M = M / np.median(M)
- dom['K'] = np.exp(-M/reg)
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ dom['K'] = K
all_domains.append(dom)
- distribT = unif(np.shape(Xt)[0])
-
- if log:
- log = {'niter': 0, 'err': []}
+ # uniform target distribution
+ a = unif(np.shape(Xt)[0])
- cpt = 0
+ cpt = 0 # iterations count
err = 1
old_bary = np.ones((nbclasses))
@@ -891,13 +867,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
bary = np.zeros((nbclasses))
+ # update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
- all_domains[d]['K'] = projC(all_domains[d]['K'], distribT)
+ all_domains[d]['K'] = projC(all_domains[d]['K'], a)
other = np.sum(all_domains[d]['K'], axis=1)
bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
bary = np.exp(bary)
+ # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
new = np.dot(all_domains[d]['D2'].T, bary)
all_domains[d]['K'] = projR(all_domains[d]['K'], new)
@@ -915,12 +893,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
print('{:5d}|{:8e}|'.format(cpt, err))
bary = bary / np.sum(bary)
+ couplings = [all_domains[d]['K'] for d in range(nbdomains)]
if log:
log['niter'] = cpt
- return bary, log
+ log['all_domains'] = all_domains
+ return couplings, bary, log
else:
- return bary
+ return couplings, bary
def distribution_estimation_uniform(X):
@@ -2093,9 +2073,10 @@ class UnbalancedSinkhornTransport(BaseTransport):
return self
+
class JCPOTTransport(BaseTransport):
- """Domain Adapatation OT method for target shift based on sinkhorn algorithm.
+ """Domain Adaptation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
Parameters
----------
@@ -2104,8 +2085,6 @@ class JCPOTTransport(BaseTransport):
max_iter : int, float, optional (default=10)
The minimum number of iteration before stopping the optimization
algorithm if no it has not converged
- max_inner_iter : int, float, optional (default=200)
- The number of iteration in the inner loop
tol : float, optional (default=10e-9)
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional (default=False)
@@ -2126,21 +2105,20 @@ class JCPOTTransport(BaseTransport):
Attributes
----------
- coupling_ : array-like, shape (n_source_samples, n_target_samples)
- The optimal coupling
+ coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples)
+ A set of optimal couplings between each source domain and the target domain
+ proportions_ : array-like, shape (n_classes,)
+ Estimated class proportions in the target domain
log_ : dictionary
The dictionary of log, empty dic if parameter log is not True
References
----------
- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE
- Transactions on Pattern Analysis and Machine Intelligence ,
- vol.PP, no.99, pp.1-1
- .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
- Generalized conditional gradient: analysis of convergence
- and applications. arXiv preprint arXiv:1510.06567.
+ .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS),
+ vol. 89, p.849-858, 2019.
"""
@@ -2156,20 +2134,18 @@ class JCPOTTransport(BaseTransport):
self.verbose = verbose
self.log = log
self.metric = metric
- self.norm = norm
- self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ """Build coupling matrices from a list of source domains and a target set of samples
(Xs, ys) and (Xt, yt)
Parameters
----------
- Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
- ys : array-like, shape (n_source_samples,)
- The class labels
+ Xs : list of K array-like objects, shape K x (nk_source_samples, n_features)
+ A list of the training input samples.
+ ys : list of K array-like objects, shape K x (nk_source_samples,)
+ A list of the class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -2188,15 +2164,90 @@ class JCPOTTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt, ys=ys):
- returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg = self.reg_e,
- metric=self.metric, numItermax=self.max_iter, stopThr=self.tol,
+ self.xs_ = Xs  # kept so transform() can compare new inputs with the fitted sources
+ self.xt_ = Xt  # kept as the support of the barycentric mapping in transform()
+
+ returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e,
+ metric=self.metric, numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
# deal with the value of log
if self.log:
- self.coupling_, self.log_ = returned_
+ self.coupling_, self.proportions_, self.log_ = returned_
else:
- self.coupling_ = returned_
+ self.coupling_, self.proportions_ = returned_
self.log_ = dict()
return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The source samples to transport. NOTE(review): compared element-wise against the list stored by fit, but indexed as a single array in the out-of-sample branch - confirm the expected type.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample transform
+ """
+
+ transp_Xs = []
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
+
+ # perform standard barycentric mapping for each source domain
+
+ for coupling in self.coupling_:
+ transp = coupling / np.sum(coupling, 1)[:, None]
+
+ # rows of the coupling with zero mass yield non-finite values; set them to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs.append(np.dot(transp, self.xt_))
+ else:
+
+ # perform out of sample mapping
+ indices = np.arange(Xs.shape[0])
+ batch_ind = [
+ indices[i:i + batch_size]
+ for i in range(0, len(indices), batch_size)]
+
+ transp_Xs = []
+
+ for bi in batch_ind:
+ transp_Xs_ = []
+
+ # index of the nearest stored source sample (all source domains stacked) for each batch sample
+ xs = np.concatenate(self.xs_, axis=0)
+ idx = np.argmin(dist(Xs[bi], xs), axis=1)
+
+ # transport the source samples
+ for coupling in self.coupling_:
+ transp = coupling / np.sum(
+ coupling, 1)[:, None]
+ transp[~ np.isfinite(transp)] = 0
+ transp_Xs_.append(np.dot(transp, self.xt_))
+
+ transp_Xs_ = np.concatenate(transp_Xs_, axis=0)
+
+ # transported nearest neighbor shifted by the offset between the new sample and that neighbor
+ transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :]
+ transp_Xs.append(transp_Xs_)
+
+ transp_Xs = np.concatenate(transp_Xs, axis=0)
+
+ return transp_Xs