author     ievred <ievgen.redko@univ-st-etienne.fr>   2020-03-31 17:12:28 +0200
committer  ievred <ievgen.redko@univ-st-etienne.fr>   2020-03-31 17:12:28 +0200
commit     6aa0f1f4e275098948d4b312530119e5d95b8884 (patch)
tree       58c78624e5e687f2453bba4df77331befa7da999
parent     171b962cea369aee2513884a1fb3dca8920b77cd (diff)
v1 jcpot example test
-rw-r--r--  examples/plot_otda_jcpot.py   185
-rw-r--r--  ot/da.py                      263
-rw-r--r--  test/test_da.py                63
3 files changed, 404 insertions, 107 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
new file mode 100644
index 0000000..5e5fff8
--- /dev/null
+++ b/examples/plot_otda_jcpot.py
@@ -0,0 +1,185 @@
+# -*- coding: utf-8 -*-
+"""
+================================
+OT for multi-source target shift
+================================
+
+This example introduces a target shift problem with two 2D source domains and one target domain.
+
+"""
+
+# Authors: Remi Flamary <remi.flamary@unice.fr>
+# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+#
+# License: MIT License
+
+import pylab as pl
+import numpy as np
+import ot
+
+##############################################################################
+# Generate data
+# -------------
+n = 50
+sigma = 0.3
+np.random.seed(1985)
+
+
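+# get_data generates a 2D two-class sample: a fraction p of points in class 1
+# (centered at (0, -1)) and 1 - p in class 0 (centered at (0, 1)), with Gaussian
+# noise of standard deviation sigma, shifted by the offset dec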
+def get_data(n, p, dec):
+ y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n))))
+ x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + sigma * np.random.randn(len(y), 2)
+
+ x[:, 0] += dec[0]
+ x[:, 1] += dec[1]
+
+ return x, y
+
+
+p1 = .2
+dec1 = [0, 2]
+
+p2 = .9
+dec2 = [0, -2]
+
+pt = .4
+dect = [4, 0]
+
+xs1, ys1 = get_data(n, p1, dec1)
+xs2, ys2 = get_data(n + 1, p2, dec2)
+xt, yt = get_data(n, pt, dect)
+all_Xr = [xs1, xs2]
+all_Yr = [ys1, ys2]
+# %%
+da = 1.5
+
+
+def plot_ax(dec, name):
+ pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)
+ pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)
+ pl.text(dec[0] - .5, dec[1] + 2, name)
+
+
+##############################################################################
+# Fig 1 : plot source and target samples
+# ---------------------------------------
+
+pl.figure(1)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, label='Source 1 (0.8,0.2)')
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, label='Source 2 (0.1,0.9)')
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target (0.6,0.4)')
+pl.title('Data')
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+
+
+##############################################################################
+# Instantiate Sinkhorn transport algorithm and fit it for each source domain
+# ----------------------------------------------------------------------------
+ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-2, metric='euclidean')
+
+M1 = ot.dist(xs1, xt, 'euclidean')
+M2 = ot.dist(xs2, xt, 'euclidean')
+
+
+def print_G(G, xs, ys, xt):
+ for i in range(G.shape[0]):
+ for j in range(G.shape[1]):
+ if G[i, j] > 5e-4:
+ if ys[i]:
+ c = 'b'
+ else:
+ c = 'r'
+ pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2)
+
+
+##############################################################################
+# Fig 2 : plot optimal couplings and transported samples
+# ------------------------------------------------------
+pl.figure(2)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)
+print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
+
+pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
+pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
+
+pl.title('Independent OT')
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+
+
+##############################################################################
+# Instantiate the JCPOT adaptation algorithm and fit it
+# ----------------------------------------------------------------------------
+otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, tol=1e-9, verbose=True, log=True)
+otda.fit(all_Xr, all_Yr, xt)
+
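+# Spread the estimated target proportions back over the source samples: D2 gives
+# each sample of class c the weight 1 / n_c, so proportions_.dot(D2) yields new
+# source weights in which class c carries total mass h_c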
+ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2'])
+ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2'])
+
+pl.figure(3)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt)
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
+
+pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
+pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
+
+pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+
+##############################################################################
+# Run oracle transport algorithm with known proportions
+# ----------------------------------------------------------------------------
+
+otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True)
+otda.fit(all_Xr, all_Yr, xt)
+
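+# oracle proportions: the true class proportions in the target domain are
+# (1 - pt) for class 0 and pt for class 1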
+h_res = np.array([1 - pt, pt])
+
+ws1 = h_res.dot(otda.log_['all_domains'][0]['D2'])
+ws2 = h_res.dot(otda.log_['all_domains'][1]['D2'])
+
+pl.figure(4)
+pl.clf()
+plot_ax(dec1, 'Source 1')
+plot_ax(dec2, 'Source 2')
+plot_ax(dect, 'Target')
+print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt)
+print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt)
+pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
+pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
+pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
+
+pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
+pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
+
+pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))
+
+pl.legend()
+pl.axis('equal')
+pl.axis('off')
+pl.show()
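Side note (not part of the committed script): the transform method added to JCPOTTransport further below performs a barycentric mapping of each source domain onto the target using the estimated couplings. A minimal, hypothetical continuation of the example could plot the transported samples; the figure numbers and labels here are illustrative only.

# Hypothetical continuation of the example above: barycentric mapping of both
# source domains onto the target with the couplings estimated by JCPOT.
transp_Xs = otda.transform(Xs=all_Xr)  # list with one transported array per source domain
for k, Xk in enumerate(transp_Xs):
    pl.figure(5 + k)
    pl.clf()
    pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, label='Target')
    pl.scatter(Xk[:, 0], Xk[:, 1], c=all_Yr[k], s=35, marker='x', cmap='Set1', vmax=9,
               label='Transported source %d' % (k + 1))
    pl.title('Barycentric mapping of source %d' % (k + 1))
    pl.legend()
    pl.axis('equal')
    pl.axis('off')
pl.show()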
diff --git a/ot/da.py b/ot/da.py
index fd5da4b..a3da8c1 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -748,79 +748,58 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
stopThr=1e-6, verbose=False, log=False, **kwargs):
- """Joint OT and proportion estimation as proposed in [27]
+ r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
The function solves the following optimization problem:
.. math::
- \mathbf{h} = \argmin_{\mathbf{h} \in \Delta_C}\quad \sum_{k=1}^K \lambda_k
- W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T \mathbf{\delta}_{\mathbf{X}^{(k)}}, \mu\right)
+ \mathbf{h} = \arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
+ W_{reg}\left((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}\right)
+ s.t. \ \forall k, \ \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n = \mathbf{h}
- s.t. \gamma^T_k \mathbf{1}_n = \mathbf{1}_n/n
-
- \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
-
- \gamma\geq 0
where :
- - M is the (ns,nt) squared euclidean cost matrix between samples in
- Xs and Xt (scaled by ns)
- - :math:`L` is a ns x d linear operator on a kernel matrix that
- approximates the barycentric mapping
- - a and b are uniform source and target weights
-
- The problem consist in solving jointly an optimal transport matrix
- :math:`\gamma` and the nonlinear mapping that fits the barycentric mapping
- :math:`n_s\gamma X_t`.
+ - :math:`\lambda_k` is the weight of the k-th source domain
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to the k-th source domain defined as in [p. 5, 27]; its expected shape is `(n_k, C)`, where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
+ - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
+ - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
+ - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27]; its expected shape is `(n_k, C)`
- One can also estimate a mapping with constant bias (see supplementary
- material of [8]) using the bias optional argument.
-
- The algorithm used for solving the problem is the block coordinate
- descent that alternates between updates of G (using conditional gradient)
- and the update of L using a classical kernel least square solver.
+ The problem consists in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
+ The algorithm used for solving the problem is the Iterative Bregman projections algorithm
+ with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and the uniform target distribution.
Parameters
----------
- xs : np.ndarray (ns,d)
- samples in the source domain
- xt : np.ndarray (nt,d)
+ Xs : list of K np.ndarray(nsk,d)
+ features of all source domains' samples
+ Ys : list of K np.ndarray(nsk,)
+ labels of all source domains' samples
+ Xt : np.ndarray (nt,d)
samples in the target domain
- mu : float,optional
- Weight for the linear OT loss (>0)
- eta : float, optional
- Regularization term for the linear mapping L (>0)
- kerneltype : str,optional
- kernel used by calling function ot.utils.kernel (gaussian by default)
- sigma : float, optional
- Gaussian kernel bandwidth.
- bias : bool,optional
- Estimate linear mapping with constant bias
- verbose : bool, optional
- Print information along iterations
- verbose2 : bool, optional
- Print information along iterations
+ reg : float
+ Regularization term > 0
+ metric : string, optional (default="sqeuclidean")
+ The ground metric for the Wasserstein problem
numItermax : int, optional
- Max number of BCD iterations
- numInnerItermax : int, optional
- Max number of iterations (inner CG solver)
- stopInnerThr : float, optional
- Stop threshold on error (inner CG solver) (>0)
+ Max number of iterations
stopThr : float, optional
- Stop threshold on relative loss decrease (>0)
+ Stop threshold on relative change in the barycenter (>0)
log : bool, optional
record log if True
-
+ verbose : bool, optional (default=False)
+ Controls the verbosity of the optimization algorithm
Returns
-------
- gamma : (ns x nt) ndarray
- Optimal transportation matrix for the given parameters
- L : (ns x d) ndarray
- Nonlinear mapping matrix (ns+1 x d if bias)
+ gamma : List of K (nsk x nt) ndarrays
+ Optimal transportation matrices for the given parameters for each pair of source and target domains
+ h : (C,) ndarray
+ proportion estimation in the target domain
log : dict
log dictionary return only if log==True in parameters
@@ -828,62 +807,59 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
References
----------
- .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
- "Mapping estimation for discrete optimal transport",
- Neural Information Processing Systems (NIPS), 2016.
-
- See Also
- --------
- ot.lp.emd : Unregularized OT
- ot.optim.cg : General regularized OT
+ .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
- """
+ '''
nbclasses = len(np.unique(Ys[0]))
nbdomains = len(Xs)
- # we then build, for each source domain, specific information
+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
all_domains = []
+
+ # log dictionary
+ if log:
+ log = {'niter': 0, 'err': [], 'all_domains': []}
+
for d in range(nbdomains):
dom = {}
- # get number of elements for this domain
- nb_elem = Xs[d].shape[0]
- dom['nbelem'] = nb_elem
- classes = np.unique(Ys[d])
+ nsk = Xs[d].shape[0] # get number of elements for this domain
+ dom['nbelem'] = nsk
+ classes = np.unique(Ys[d]) # get the classes present in this domain
+ # format classes to start from 0 for convenience
if np.min(classes) != 0:
Ys[d] = Ys[d] - np.min(classes)
classes = np.unique(Ys[d])
- # build the corresponding D matrix
- D1 = np.zeros((nbclasses, nb_elem))
- D2 = np.zeros((nbclasses, nb_elem))
- classes_d = np.zeros(nbclasses)
-
- classes_d[np.unique(Ys[d]).astype(int)] = 1
- dom['classes'] = classes_d
+ # build the corresponding D_1 and D_2 matrices
+ D1 = np.zeros((nbclasses, nsk))
+ D2 = np.zeros((nbclasses, nsk))
for c in classes:
nbelemperclass = np.sum(Ys[d] == c)
if nbelemperclass != 0:
D1[int(c), Ys[d] == c] = 1.
- D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) # *nbclasses_d)
+ D2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
dom['D1'] = D1
dom['D2'] = D2
- # build the distance matrix
+ # build the cost matrix and the Gibbs kernel
M = dist(Xs[d], Xt, metric=metric)
M = M / np.median(M)
- dom['K'] = np.exp(-M/reg)
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+ dom['K'] = K
all_domains.append(dom)
- distribT = unif(np.shape(Xt)[0])
-
- if log:
- log = {'niter': 0, 'err': []}
+ # uniform target distribution
+ a = unif(np.shape(Xt)[0])
- cpt = 0
+ cpt = 0 # iterations count
err = 1
old_bary = np.ones((nbclasses))
@@ -891,13 +867,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
bary = np.zeros((nbclasses))
+ # update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
- all_domains[d]['K'] = projC(all_domains[d]['K'], distribT)
+ all_domains[d]['K'] = projC(all_domains[d]['K'], a)
other = np.sum(all_domains[d]['K'], axis=1)
bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
bary = np.exp(bary)
+ # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
new = np.dot(all_domains[d]['D2'].T, bary)
all_domains[d]['K'] = projR(all_domains[d]['K'], new)
@@ -915,12 +893,14 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
print('{:5d}|{:8e}|'.format(cpt, err))
bary = bary / np.sum(bary)
+ couplings = [all_domains[d]['K'] for d in range(nbdomains)]
if log:
log['niter'] = cpt
- return bary, log
+ log['all_domains'] = all_domains
+ return couplings, bary, log
else:
- return bary
+ return couplings, bary
def distribution_estimation_uniform(X):
@@ -2093,9 +2073,10 @@ class UnbalancedSinkhornTransport(BaseTransport):
return self
+
class JCPOTTransport(BaseTransport):
- """Domain Adapatation OT method for target shift based on sinkhorn algorithm.
+ """Domain Adaptation OT method for multi-source target shift based on the Wasserstein barycenter algorithm.
Parameters
----------
@@ -2104,8 +2085,6 @@ class JCPOTTransport(BaseTransport):
max_iter : int, float, optional (default=10)
The minimum number of iterations before stopping the optimization
algorithm if it has not converged
- max_inner_iter : int, float, optional (default=200)
- The number of iteration in the inner loop
tol : float, optional (default=10e-9)
Stop threshold on error (inner sinkhorn solver) (>0)
verbose : bool, optional (default=False)
@@ -2126,21 +2105,20 @@ class JCPOTTransport(BaseTransport):
Attributes
----------
- coupling_ : array-like, shape (n_source_samples, n_target_samples)
- The optimal coupling
+ coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples)
+ A set of optimal couplings between each source domain and the target domain
+ proportions_ : array-like, shape (n_classes,)
+ Estimated class proportions in the target domain
log_ : dictionary
The dictionary of log, empty dic if parameter log is not True
References
----------
- .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE
- Transactions on Pattern Analysis and Machine Intelligence ,
- vol.PP, no.99, pp.1-1
- .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
- Generalized conditional gradient: analysis of convergence
- and applications. arXiv preprint arXiv:1510.06567.
+ .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
+ "Optimal transport for multi-source domain adaptation under target shift",
+ International Conference on Artificial Intelligence and Statistics (AISTATS),
+ vol. 89, p.849-858, 2019.
"""
@@ -2156,20 +2134,18 @@ class JCPOTTransport(BaseTransport):
self.verbose = verbose
self.log = log
self.metric = metric
- self.norm = norm
- self.distribution_estimation = distribution_estimation
self.out_of_sample_map = out_of_sample_map
def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Build a coupling matrix from source and target sets of samples
+ """Build coupling matrices from lists of source samples and a target set of samples
(Xs, ys) and (Xt, yt)
Parameters
----------
- Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
- ys : array-like, shape (n_source_samples,)
- The class labels
+ Xs : list of K array-like objects, shape K x (nk_source_samples, n_features)
+ A list of the training input samples.
+ ys : list of K array-like objects, shape K x (nk_source_samples,)
+ A list of the class labels
Xt : array-like, shape (n_target_samples, n_features)
The training input samples.
yt : array-like, shape (n_target_samples,)
@@ -2188,15 +2164,90 @@ class JCPOTTransport(BaseTransport):
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt, ys=ys):
- returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg = self.reg_e,
- metric=self.metric, numItermax=self.max_iter, stopThr=self.tol,
+ self.xs_ = Xs
+ self.xt_ = Xt
+
+ returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e,
+ metric=self.metric, numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
# deal with the value of log
if self.log:
- self.coupling_, self.log_ = returned_
+ self.coupling_, self.proportions_, self.log_ = returned_
else:
- self.coupling_ = returned_
+ self.coupling_, self.proportions_ = returned_
self.log_ = dict()
return self
+
+ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
+ """Transports source samples Xs onto target ones Xt
+
+ Parameters
+ ----------
+ Xs : array-like, shape (n_source_samples, n_features)
+ The training input samples.
+ ys : array-like, shape (n_source_samples,)
+ The class labels
+ Xt : array-like, shape (n_target_samples, n_features)
+ The training input samples.
+ yt : array-like, shape (n_target_samples,)
+ The class labels. If some target samples are unlabeled, fill the
+ yt's elements with -1.
+
+ Warning: Note that, due to this convention -1 cannot be used as a
+ class label
+ batch_size : int, optional (default=128)
+ The batch size for out of sample transform
+ """
+
+ transp_Xs = []
+
+ # check the necessary inputs parameters are here
+ if check_params(Xs=Xs):
+
+ if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
+
+ # perform standard barycentric mapping for each source domain
+
+ for coupling in self.coupling_:
+ transp = coupling / np.sum(coupling, 1)[:, None]
+
+ # set nans to 0
+ transp[~ np.isfinite(transp)] = 0
+
+ # compute transported samples
+ transp_Xs.append(np.dot(transp, self.xt_))
+ else:
+
+ # perform out of sample mapping
+ indices = np.arange(Xs.shape[0])
+ batch_ind = [
+ indices[i:i + batch_size]
+ for i in range(0, len(indices), batch_size)]
+
+ transp_Xs = []
+
+ for bi in batch_ind:
+ transp_Xs_ = []
+
+ # get the nearest neighbor in the sources domains
+ xs = np.concatenate(self.xs_, axis=0)
+ idx = np.argmin(dist(Xs[bi], xs), axis=1)
+
+ # transport the source samples
+ for coupling in self.coupling_:
+ transp = coupling / np.sum(
+ coupling, 1)[:, None]
+ transp[~ np.isfinite(transp)] = 0
+ transp_Xs_.append(np.dot(transp, self.xt_))
+
+ transp_Xs_ = np.concatenate(transp_Xs_, axis=0)
+
+ # define the transported points
+ transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :]
+ transp_Xs.append(transp_Xs_)
+
+ transp_Xs = np.concatenate(transp_Xs, axis=0)
+
+ return transp_Xs
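For completeness, a minimal sketch of calling the reworked jcpot_barycenter directly, assuming it can be imported from ot.da as defined above. The random data below is purely illustrative; the returned values follow the new signature documented in the docstring (one coupling per source domain, the estimated proportions, and an optional log).

import numpy as np
from ot.da import jcpot_barycenter

rng = np.random.RandomState(42)
# two small synthetic source domains with binary labels and an unlabeled target
Xs = [rng.randn(40, 2), rng.randn(60, 2) + 1]
Ys = [rng.randint(0, 2, 40), rng.randint(0, 2, 60)]
Xt = rng.randn(50, 2) + 0.5

couplings, h, log = jcpot_barycenter(Xs, Ys, Xt, reg=1e-2, metric='sqeuclidean',
                                     numItermax=100, stopThr=1e-6, log=True)
print(h)               # estimated class proportions in the target domain, shape (C,)
print(len(couplings))  # one coupling matrix per source domain
print(log['niter'])    # number of Bregman iterations performed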
diff --git a/test/test_da.py b/test/test_da.py
index 2a5e50e..a8c258a 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -5,7 +5,7 @@
# License: MIT License
import numpy as np
-from numpy.testing.utils import assert_allclose, assert_equal
+from numpy.testing import assert_allclose, assert_equal
import ot
from ot.datasets import make_data_classif
@@ -549,3 +549,64 @@ def test_linear_mapping_class():
Cst = np.cov(Xst.T)
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
+
+
+def test_jcpot_transport_class():
+ """test_jcpot_transport
+ """
+
+ ns1 = 150
+ ns2 = 150
+ nt = 200
+
+ Xs1, ys1 = make_data_classif('3gauss', ns1)
+ Xs2, ys2 = make_data_classif('3gauss', ns2)
+
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs = [Xs1, Xs2]
+ ys = [ys1, ys2]
+
+ otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True)
+
+ # test its computed
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+ print(otda.proportions_)
+
+ assert hasattr(otda, "coupling_")
+ assert hasattr(otda, "proportions_")
+ assert hasattr(otda, "log_")
+
+ # test dimensions of coupling
+ for i, xs in enumerate(Xs):
+ assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0])))
+
+ # test all margin constraints
+ mu_t = unif(nt)
+
+ for i in range(len(Xs)):
+ # test margin constraints w.r.t. uniform target weights for each coupling matrix
+ assert_allclose(
+ np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3)
+
+ # test margin constraints w.r.t. modified source weights for each source domain
+
+ D1 = np.zeros((len(np.unique(ys[i])), len(ys[i])))
+ for c in np.unique(ys[i]):
+ nbelemperclass = np.sum(ys[i] == c)
+ if nbelemperclass != 0:
+ D1[int(c), ys[i] == c] = 1.
+
+ assert_allclose(
+ np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3)
+
+ # test transform
+ transp_Xs = otda.transform(Xs=Xs)
+ [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
+ #assert_equal(transp_Xs.shape, Xs.shape)
+
+ Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
+ transp_Xs_new = otda.transform(Xs_new)
+
+ # check that the oos method is working
+ assert_equal(transp_Xs_new.shape, Xs_new.shape)
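A hypothetical additional sanity check, not part of the committed test, could compare the estimated proportions with the empirical class proportions of the target labels yt generated in the test (yt is otherwise unused there); it assumes the test's np, yt and otda are in scope.

# empirical class proportions in the target domain, in sorted class order
classes, counts = np.unique(yt, return_counts=True)
true_prop = counts / counts.sum()
print(otda.proportions_, true_prop)  # estimated vs. empirical target proportions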