diff options
-rw-r--r-- | .travis.yml | 1 | ||||
-rw-r--r-- | README.md | 3 | ||||
-rw-r--r-- | examples/plot_otda_classes.py | 1 | ||||
-rw-r--r-- | examples/plot_otda_jcpot.py | 171 | ||||
-rw-r--r-- | ot/bregman.py | 224 | ||||
-rw-r--r-- | ot/da.py | 373 | ||||
-rw-r--r-- | ot/datasets.py | 17 | ||||
-rw-r--r-- | ot/lp/__init__.py | 2 | ||||
-rw-r--r-- | ot/plot.py | 3 | ||||
-rw-r--r-- | ot/utils.py | 22 | ||||
-rw-r--r-- | test/test_da.py | 146 |
11 files changed, 908 insertions, 55 deletions
diff --git a/.travis.yml b/.travis.yml index 5b3a26e..072bc55 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,6 +35,7 @@ install: - pip install -r requirements.txt - pip install -U "numpy>=1.14" "scipy<1.3" # for numpy array formatting in doctests + scipy version: otherwise, pymanopt fails, cf <https://github.com/pymanopt/pymanopt/issues/77> - pip install flake8 pytest "pytest-cov<2.6" + - pip install -U "sklearn" - pip install . # command to run tests + check syntax style services: @@ -29,6 +29,7 @@ It provides the following solvers: * Non regularized free support Wasserstein barycenters [20]. * Unbalanced OT with KL relaxation distance and barycenter [10, 25]. * Screening Sinkhorn Algorithm for OT [26]. +* JCPOT algorithm for multi-source domain adaptation with target shift [27]. Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. @@ -257,3 +258,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. (2015). [Learning with a Wasserstein Loss](http://cbcl.mit.edu/wasserstein/) Advances in Neural Information Processing Systems (NIPS). [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). [Screening Sinkhorn Algorithm for Regularized Optimal Transport](https://papers.nips.cc/paper/9386-screening-sinkhorn-algorithm-for-regularized-optimal-transport), Advances in Neural Information Processing Systems 33 (NeurIPS). + +[27] Redko I., Courty N., Flamary R., Tuia D. (2019). [Optimal Transport for Multi-source Domain Adaptation under Target Shift](http://proceedings.mlr.press/v89/redko19a.html), Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics (AISTATS) 22, 2019.
\ No newline at end of file diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py index c311fbd..f028022 100644 --- a/examples/plot_otda_classes.py +++ b/examples/plot_otda_classes.py @@ -17,7 +17,6 @@ approaches currently supported in POT. import matplotlib.pylab as pl import ot - ############################################################################## # Generate data # ------------- diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py new file mode 100644 index 0000000..c495690 --- /dev/null +++ b/examples/plot_otda_jcpot.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +""" +======================== +OT for multi-source target shift +======================== + +This example introduces a target shift problem with two 2D source and 1 target domain. + +""" + +# Authors: Remi Flamary <remi.flamary@unice.fr> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> +# +# License: MIT License + +import pylab as pl +import numpy as np +import ot +from ot.datasets import make_data_classif + +############################################################################## +# Generate data +# ------------- +n = 50 +sigma = 0.3 +np.random.seed(1985) + +p1 = .2 +dec1 = [0, 2] + +p2 = .9 +dec2 = [0, -2] + +pt = .4 +dect = [4, 0] + +xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1) +xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2) +xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect) + +all_Xr = [xs1, xs2] +all_Yr = [ys1, ys2] +# %% + +da = 1.5 + + +def plot_ax(dec, name): + pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5) + pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5) + pl.text(dec[0] - .5, dec[1] + 2, name) + + +############################################################################## +# Fig 1 : plots source and target samples +# --------------------------------------- + +pl.figure(1) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9, + label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1)) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9, + label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2)) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9, + label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt)) +pl.title('Data') + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Instantiate Sinkhorn transport algorithm and fit them for all source domains +# ---------------------------------------------------------------------------- +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') + + +def print_G(G, xs, ys, xt): + for i in range(G.shape[0]): + for j in range(G.shape[1]): + if G[i, j] > 5e-4: + if ys[i]: + c = 'b' + else: + c = 'r' + pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2) + + +############################################################################## +# Fig 2 : plot optimal couplings and transported samples +# ------------------------------------------------------ +pl.figure(2) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt) +print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('Independent OT') + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Instantiate JCPOT adaptation algorithm and fit it +# ---------------------------------------------------------------------------- +otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) +otda.fit(all_Xr, all_Yr, xt) + +ws1 = otda.proportions_.dot(otda.log_['D2'][0]) +ws2 = otda.proportions_.dot(otda.log_['D2'][1]) + +pl.figure(3) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1])) + +pl.legend() +pl.axis('equal') +pl.axis('off') + +############################################################################## +# Run oracle transport algorithm with known proportions +# ---------------------------------------------------------------------------- +h_res = np.array([1 - pt, pt]) + +ws1 = h_res.dot(otda.log_['D2'][0]) +ws2 = h_res.dot(otda.log_['D2'][1]) + +pl.figure(4) +pl.clf() +plot_ax(dec1, 'Source 1') +plot_ax(dec2, 'Source 2') +plot_ax(dect, 'Target') +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) +pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) +pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) +pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) + +pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1') +pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2') + +pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1])) + +pl.legend() +pl.axis('equal') +pl.axis('off') +pl.show() diff --git a/ot/bregman.py b/ot/bregman.py index d5e3563..543dbaa 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -10,6 +10,7 @@ Bregman projections for regularized OT # Hicham Janati <hicham.janati@inria.fr> # Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com> # Alexander Tong <alexander.tong@yale.edu> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> # # License: MIT License @@ -539,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, old_v = v[i_2] v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) G[:, i_2] = u * K[:, i_2] * v[i_2] - #aviol = (G@one_m - a) - #aviol_2 = (G.T@one_n - b) + # aviol = (G@one_m - a) + # aviol_2 = (G.T@one_n - b) viol += (-old_v + v[i_2]) * K[:, i_2] * u viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] - #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) + # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: break @@ -940,7 +941,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # the 10th iterations transp = G err = np.linalg.norm( - (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2 + (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2 if log: log['err'].append(err) @@ -966,7 +967,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" - assert(len(weights) == alldistribT.shape[1]) + assert (len(weights) == alldistribT.shape[1]) return np.exp(np.dot(np.log(alldistribT), weights.T)) @@ -1108,7 +1109,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if weights is None: weights = np.ones(A.shape[1]) / A.shape[1] else: - assert(len(weights) == A.shape[1]) + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -1206,7 +1207,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, if weights is None: weights = np.ones(n_hists) / n_hists else: - assert(len(weights) == A.shape[1]) + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -1334,7 +1335,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if weights is None: weights = np.ones(A.shape[0]) / A.shape[0] else: - assert(len(weights) == A.shape[0]) + assert (len(weights) == A.shape[0]) if log: log = {'err': []} @@ -1350,11 +1351,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, # this is equivalent to blurring on horizontal then vertical directions t = np.linspace(0, 1, A.shape[1]) [Y, X] = np.meshgrid(t, t) - xi1 = np.exp(-(X - Y)**2 / reg) + xi1 = np.exp(-(X - Y) ** 2 / reg) t = np.linspace(0, 1, A.shape[2]) [Y, X] = np.meshgrid(t, t) - xi2 = np.exp(-(X - Y)**2 / reg) + xi2 = np.exp(-(X - Y) ** 2 / reg) def K(x): return np.dot(np.dot(xi1, x), xi2) @@ -1502,6 +1503,164 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, return np.sum(K0, axis=1) +def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, + stopThr=1e-6, verbose=False, log=False, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] + + The function solves the following optimization problem: + + .. math:: + + \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) + + s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} + + where : + + - :math:`\lambda_k` is the weight of k-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` + + The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + + The algorithm used for solving the problem is the Iterative Bregman projections algorithm + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. + + Parameters + ---------- + Xs : list of K np.ndarray(nsk,d) + features of all source domains' samples + Ys : list of K np.ndarray(nsk,) + labels of all source domains' samples + Xt : np.ndarray (nt,d) + samples in the target domain + reg : float + Regularization term > 0 + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on relative change in the barycenter (>0) + log : bool, optional + record log if True + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + + Returns + ------- + h : (C,) ndarray + proportion estimation in the target domain + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + + ''' + nbclasses = len(np.unique(Ys[0])) + nbdomains = len(Xs) + + # log dictionary + if log: + log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []} + + K = [] + M = [] + D1 = [] + D2 = [] + + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 + for d in range(nbdomains): + dom = {} + nsk = Xs[d].shape[0] # get number of elements for this domain + dom['nbelem'] = nsk + classes = np.unique(Ys[d]) # get number of classes for this domain + + # format classes to start from 0 for convenience + if np.min(classes) != 0: + Ys[d] = Ys[d] - np.min(classes) + classes = np.unique(Ys[d]) + + # build the corresponding D_1 and D_2 matrices + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) + + for c in classes: + nbelemperclass = np.sum(Ys[d] == c) + if nbelemperclass != 0: + Dtmp1[int(c), Ys[d] == c] = 1. + Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) + D1.append(Dtmp1) + D2.append(Dtmp2) + + # build the cost matrix and the Gibbs kernel + Mtmp = dist(Xs[d], Xt, metric=metric) + M.append(Mtmp) + + Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) + np.divide(Mtmp, -reg, out=Ktmp) + np.exp(Ktmp, out=Ktmp) + K.append(Ktmp) + + # uniform target distribution + a = unif(np.shape(Xt)[0]) + + cpt = 0 # iterations count + err = 1 + old_bary = np.ones((nbclasses)) + + while (err > stopThr and cpt < numItermax): + + bary = np.zeros((nbclasses)) + + # update coupling matrices for marginal constraints w.r.t. uniform target distribution + for d in range(nbdomains): + K[d] = projC(K[d], a) + other = np.sum(K[d], axis=1) + bary = bary + np.log(np.dot(D1[d], other)) / nbdomains + + bary = np.exp(bary) + + # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] + for d in range(nbdomains): + new = np.dot(D2[d].T, bary) + K[d] = projR(K[d], new) + + err = np.linalg.norm(bary - old_bary) + cpt = cpt + 1 + old_bary = bary + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + bary = bary / np.sum(bary) + + if log: + log['niter'] = cpt + log['M'] = M + log['D1'] = D1 + log['D2'] = D2 + log['gamma'] = K + return bary, log + else: + return bary + + def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): @@ -1593,7 +1752,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1675,14 +1835,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num M = dist(X_s, X_t, metric=metric) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -1768,11 +1931,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 ''' if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, log=log, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -1787,11 +1953,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli return max(0, sinkhorn_div), log else: - sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) @@ -1883,7 +2052,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res try: import bottleneck except ImportError: - warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") + warnings.warn( + "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np a = np.asarray(a, dtype=np.float64) @@ -2019,8 +2189,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget - bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + bounds_v = [( + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) @@ -2069,7 +2240,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res return usc, vsc def screened_obj(usc, vsc): - part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc)) + part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, + np.log(vsc)) part_IJc = np.dot(usc, vec_eps_IJc) part_IcJ = np.dot(vec_eps_IcJ, vsc) psi_epsilon = part_IJ + part_IJc + part_IcJ @@ -2091,9 +2263,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res g = np.hstack([g_u, g_v]) return f, g - #----------------------------------------------------------------------------------------------------------------# + # ----------------------------------------------------------------------------------------------------------------# # Step 2: L-BFGS-B solver # - #----------------------------------------------------------------------------------------------------------------# + # ----------------------------------------------------------------------------------------------------------------# u0, v0 = restricted_sinkhorn(u0, v0) theta0 = np.hstack([u0, v0]) @@ -7,15 +7,16 @@ Domain adaptation with optimal transport # Nicolas Courty <ncourty@irisa.fr> # Michael Perrot <michael.perrot@univ-st-etienne.fr> # Nathalie Gayraud <nat.gayraud@gmail.com> +# Ievgen Redko <ievgen.redko@univ-st-etienne.fr> # # License: MIT License import numpy as np import scipy.linalg as linalg -from .bregman import sinkhorn +from .bregman import sinkhorn, jcpot_barycenter from .lp import emd -from .utils import unif, dist, kernel, cost_normalization +from .utils import unif, dist, kernel, cost_normalization, label_normalization from .utils import check_params, BaseEstimator from .unbalanced import sinkhorn_unbalanced from .optim import cg @@ -127,7 +128,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, W = np.ones(M.shape) for (i, c) in enumerate(classes): majs = np.sum(transp[indices_labels[i]], axis=0) - majs = p * ((majs + epsilon)**(p - 1)) + majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs return transp @@ -359,8 +360,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \ - np.sum(G * M) + eta * np.sum(sel(L - I0)**2) + return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ + np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2) def solve_L(G): """ solve L problem with fixed G (least square)""" @@ -372,10 +373,11 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False, xsi = xs1.dot(L) def f(G): - return np.sum((xsi - ns * G.dot(xt))**2) + return np.sum((xsi - ns * G.dot(xt)) ** 2) def df(G): return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) return G @@ -562,7 +564,7 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \ + return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \ np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) def solve_L_nobias(G): @@ -580,10 +582,11 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', xsi = K1.dot(L) def f(G): - return np.sum((xsi - ns * G.dot(xt))**2) + return np.sum((xsi - ns * G.dot(xt)) ** 2) def df(G): return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T) + G = cg(a, b, M, 1.0 / mu, f, df, G0=G0, numItermax=numInnerItermax, stopThr=stopInnerThr) return G @@ -783,6 +786,9 @@ class BaseTransport(BaseEstimator): transform method should always get as input a Xs parameter inverse_transform method should always get as input a Xt parameter + + transform_labels method should always get as input a ys parameter + inverse_transform_labels method should always get as input a yt parameter """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): @@ -921,7 +927,6 @@ class BaseTransport(BaseEstimator): transp_Xs = [] for bi in batch_ind: - # get the nearest neighbor in the source domain D0 = dist(Xs[bi], self.xs_) idx = np.argmin(D0, axis=1) @@ -941,6 +946,50 @@ class BaseTransport(BaseEstimator): return transp_Xs + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain estimated target labels as in [27] + + Parameters + ---------- + ys : array-like, shape (n_source_samples,) + The class labels + + Returns + ------- + transp_ys : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + + ysTemp = label_normalization(np.copy(ys)) + classes = np.unique(ysTemp) + n = len(classes) + D1 = np.zeros((n, len(ysTemp))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + for c in classes: + D1[int(c), ysTemp == c] = 1 + + # compute propagated labels + transp_ys = np.dot(D1, transp) + + return transp_ys.T + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): """Transports target samples Xt onto target samples Xs @@ -990,7 +1039,6 @@ class BaseTransport(BaseEstimator): transp_Xt = [] for bi in batch_ind: - D0 = dist(Xt[bi], self.xt_) idx = np.argmin(D0, axis=1) @@ -1009,8 +1057,44 @@ class BaseTransport(BaseEstimator): return transp_Xt + def inverse_transform_labels(self, yt=None): + """Propagate target labels yt to obtain estimated source labels ys + + Parameters + ---------- + yt : array-like, shape (n_target_samples,) + + Returns + ------- + transp_ys : array-like, shape (n_source_samples, nb_classes) + Estimated soft source labels. + """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) + n = len(classes) + D1 = np.zeros((n, len(ytTemp))) + + # perform label propagation + transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + for c in classes: + D1[int(c), ytTemp == c] = 1 + + # compute propagated samples + transp_ys = np.dot(D1, transp.T) + + return transp_ys.T + class LinearTransport(BaseTransport): + """ OT linear operator between empirical distributions The function estimates the optimal linear operator that aligns the two @@ -1055,7 +1139,6 @@ class LinearTransport(BaseTransport): def __init__(self, reg=1e-8, bias=True, log=False, distribution_estimation=distribution_estimation_uniform): - self.bias = bias self.log = log self.reg = reg @@ -1136,7 +1219,6 @@ class LinearTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs = Xs.dot(self.A_) + self.B_ return transp_Xs @@ -1170,7 +1252,6 @@ class LinearTransport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt = Xt.dot(self.A1_) + self.B1_ return transp_Xt @@ -1231,7 +1312,6 @@ class SinkhornTransport(BaseTransport): metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): - self.reg_e = reg_e self.max_iter = max_iter self.tol = tol @@ -1329,7 +1409,6 @@ class EMDTransport(BaseTransport): distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10, max_iter=100000): - self.metric = metric self.norm = norm self.log = log @@ -1440,7 +1519,6 @@ class SinkhornLpl1Transport(BaseTransport): metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): - self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter @@ -1481,7 +1559,6 @@ class SinkhornLpl1Transport(BaseTransport): # check the necessary inputs parameters are here if check_params(Xs=Xs, Xt=Xt, ys=ys): - super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt) returned_ = sinkhorn_lpl1_mm( @@ -1563,7 +1640,6 @@ class SinkhornL1l2Transport(BaseTransport): metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): - self.reg_e = reg_e self.reg_cl = reg_cl self.max_iter = max_iter @@ -1685,7 +1761,6 @@ class MappingTransport(BaseEstimator): norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5, max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False, verbose2=False): - self.metric = metric self.norm = norm self.mu = mu @@ -1856,7 +1931,6 @@ class UnbalancedSinkhornTransport(BaseTransport): metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=10): - self.reg_e = reg_e self.reg_m = reg_m self.method = method @@ -1914,3 +1988,262 @@ class UnbalancedSinkhornTransport(BaseTransport): self.log_ = dict() return self + + +class JCPOTTransport(BaseTransport): + + """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm. + + Parameters + ---------- + reg_e : float, optional (default=1) + Entropic regularization parameter + max_iter : int, float, optional (default=10) + The minimum number of iteration before stopping the optimization + algorithm if no it has not converged + tol : float, optional (default=10e-9) + Stop threshold on error (inner sinkhorn solver) (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional (default=False) + Controls the logs of the optimization algorithm + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + norm : string, optional (default=None) + If given, normalize the ground metric to avoid numerical errors that + can occur with large metric values. + distribution_estimation : callable, optional (defaults to the uniform) + The kind of distribution estimation to employ + out_of_sample_map : string, optional (default="ferradans") + The kind of out of sample mapping to apply to transport samples + from a domain into another one. Currently the only possible option is + "ferradans" which uses the method proposed in [6]. + + Attributes + ---------- + coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples) + A set of optimal couplings between each source domain and the target domain + proportions_ : array-like, shape (n_classes,) + Estimated class proportions in the target domain + log_ : dictionary + The dictionary of log, empty dic if parameter log is not True + + References + ---------- + + .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), + vol. 89, p.849-858, 2019. + + """ + + def __init__(self, reg_e=.1, max_iter=10, + tol=10e-9, verbose=False, log=False, + metric="sqeuclidean", + out_of_sample_map='ferradans'): + self.reg_e = reg_e + self.max_iter = max_iter + self.tol = tol + self.verbose = verbose + self.log = log + self.metric = metric + self.out_of_sample_map = out_of_sample_map + + def fit(self, Xs, ys=None, Xt=None, yt=None): + """Building coupling matrices from a list of source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs, Xt=Xt, ys=ys): + + self.xs_ = Xs + self.xt_ = Xt + + returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, + metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, + verbose=self.verbose, log=True) + + self.coupling_ = returned_[1]['gamma'] + + # deal with the value of log + if self.log: + self.proportions_, self.log_ = returned_ + else: + self.proportions_ = returned_ + self.log_ = dict() + + return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : list of K array-like objects, shape K x (nk_source_samples, n_features) + A list of the training input samples. + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + """ + + transp_Xs = [] + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]): + + # perform standard barycentric mapping for each source domain + + for coupling in self.coupling_: + transp = coupling / np.sum(coupling, 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute transported samples + transp_Xs.append(np.dot(transp, self.xt_)) + else: + + # perform out of sample mapping + indices = np.arange(Xs.shape[0]) + batch_ind = [ + indices[i:i + batch_size] + for i in range(0, len(indices), batch_size)] + + transp_Xs = [] + + for bi in batch_ind: + transp_Xs_ = [] + + # get the nearest neighbor in the sources domains + xs = np.concatenate(self.xs_, axis=0) + idx = np.argmin(dist(Xs[bi], xs), axis=1) + + # transport the source samples + for coupling in self.coupling_: + transp = coupling / np.sum( + coupling, 1)[:, None] + transp[~ np.isfinite(transp)] = 0 + transp_Xs_.append(np.dot(transp, self.xt_)) + + transp_Xs_ = np.concatenate(transp_Xs_, axis=0) + + # define the transported points + transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :] + transp_Xs.append(transp_Xs_) + + transp_Xs = np.concatenate(transp_Xs, axis=0) + + return transp_Xs + + def transform_labels(self, ys=None): + """Propagate source labels ys to obtain target labels as in [27] + + Parameters + ---------- + ys : list of K array-like objects, shape K x (nk_source_samples,) + A list of the class labels + + Returns + ------- + yt : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. + """ + + # check the necessary inputs parameters are here + if check_params(ys=ys): + yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0])) + for i in range(len(ys)): + ysTemp = label_normalization(np.copy(ys[i])) + classes = np.unique(ysTemp) + n = len(classes) + ns = len(ysTemp) + + # perform label propagation + transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + if self.log: + D1 = self.log_['D1'][i] + else: + D1 = np.zeros((n, ns)) + + for c in classes: + D1[int(c), ysTemp == c] = 1 + + # compute propagated labels + yt = yt + np.dot(D1, transp) / len(ys) + + return yt.T + + def inverse_transform_labels(self, yt=None): + """Propagate source labels ys to obtain target labels + + Parameters + ---------- + yt : array-like, shape (n_source_samples,) + The target class labels + + Returns + ------- + transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) + A list of estimated soft source labels + """ + + # check the necessary inputs parameters are here + if check_params(yt=yt): + transp_ys = [] + ytTemp = label_normalization(np.copy(yt)) + classes = np.unique(ytTemp) + n = len(classes) + D1 = np.zeros((n, len(ytTemp))) + + for c in classes: + D1[int(c), ytTemp == c] = 1 + + for i in range(len(self.xs_)): + + # perform label propagation + transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None] + + # set nans to 0 + transp[~ np.isfinite(transp)] = 0 + + # compute propagated labels + transp_ys.append(np.dot(D1, transp.T).T) + + return transp_ys diff --git a/ot/datasets.py b/ot/datasets.py index ba0cfd9..a1ca7b6 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -30,7 +30,7 @@ def make_1D_gauss(n, m, s): 1D histogram for a gaussian distribution """ x = np.arange(n, dtype=np.float64) - h = np.exp(-(x - m)**2 / (2 * s**2)) + h = np.exp(-(x - m) ** 2 / (2 * s ** 2)) return h / h.sum() @@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None): return make_2D_samples_gauss(n, m, sigma, random_state=None) -def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): +def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs): """Dataset generation for classification problems Parameters @@ -91,6 +91,8 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): number of training samples nz : float noise level (>0) + p : float + proportion of one class in the binary setting random_state : int, RandomState instance or None, optional (default=None) If int, random_state is the seed used by the random number generator; If RandomState instance, random_state is the random number generator; @@ -150,6 +152,17 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs): x = x.dot(rot) + elif dataset.lower() == '2gauss_prop': + + y = np.concatenate((np.ones(int(p * n)), np.zeros(int((1 - p) * n)))) + x = np.hstack((0 * y[:, None] - 0, 1 - 2 * y[:, None])) + nz * np.random.randn(len(y), 2) + + if ('bias' not in kwargs) and ('b' not in kwargs): + kwargs['bias'] = np.array([0, 2]) + + x[:, 0] += kwargs['bias'][0] + x[:, 1] += kwargs['bias'][1] + else: x = np.array(0) y = np.array(0) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index f4f6861..8d1baa0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -2,8 +2,6 @@ """ Solvers for the original linear program OT problem - - """ # Author: Remi Flamary <remi.flamary@unice.fr> @@ -78,9 +78,10 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): thr : float, optional threshold above which the line is drawn **kwargs : dict - paameters given to the plot functions (default color is black if + parameters given to the plot functions (default color is black if nothing given) """ + if ('color' not in kwargs) and ('c' not in kwargs): kwargs['color'] = 'k' mx = G.max() diff --git a/ot/utils.py b/ot/utils.py index b71458b..c154f99 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -200,6 +200,28 @@ def dots(*args): return reduce(np.dot, args) +def label_normalization(y, start=0): + """ Transform labels to start at a given value + + Parameters + ---------- + y : array-like, shape (n, ) + The vector of labels to be normalized. + start : int + Desired value for the smallest label in y (default=0) + + Returns + ------- + y : array-like, shape (n1, ) + The input vector of labels normalized according to given start value. + """ + + diff = np.min(np.unique(y)) - start + if diff != 0: + y -= diff + return y + + def fun(f, q_in, q_out): """ Utility function for parmap with no serializing problems """ while True: diff --git a/test/test_da.py b/test/test_da.py index 2a5e50e..7d0fdda 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -5,7 +5,7 @@ # License: MIT License import numpy as np -from numpy.testing.utils import assert_allclose, assert_equal +from numpy.testing import assert_allclose, assert_equal import ot from ot.datasets import make_data_classif @@ -65,6 +65,16 @@ def test_sinkhorn_lpl1_transport_class(): transp_Xs = otda.fit_transform(Xs=Xs, ys=ys, Xt=Xt) assert_equal(transp_Xs.shape, Xs.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) + # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -129,6 +139,16 @@ def test_sinkhorn_l1l2_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) + Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -210,6 +230,16 @@ def test_sinkhorn_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) + Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -271,6 +301,16 @@ def test_unbalanced_sinkhorn_transport_class(): transp_Xs = otda.transform(Xs=Xs) assert_equal(transp_Xs.shape, Xs.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) + Xs_new, _ = make_data_classif('3gauss', ns + 1) transp_Xs_new = otda.transform(Xs_new) @@ -353,6 +393,16 @@ def test_emd_transport_class(): transp_Xt = otda.inverse_transform(Xt=Xt) assert_equal(transp_Xt.shape, Xt.shape) + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) + Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -510,7 +560,6 @@ def test_mapping_transport_class(): def test_linear_mapping(): - ns = 150 nt = 200 @@ -528,7 +577,6 @@ def test_linear_mapping(): def test_linear_mapping_class(): - ns = 150 nt = 200 @@ -549,3 +597,95 @@ def test_linear_mapping_class(): Cst = np.cov(Xst.T) np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + +def test_jcpot_transport_class(): + """test_jcpot_transport + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + Xs1, ys1 = make_data_classif('3gauss', ns1) + Xs2, ys2 = make_data_classif('3gauss', ns2) + + Xt, yt = make_data_classif('3gauss2', nt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + otda = ot.da.JCPOTTransport(reg_e=1, max_iter=10000, tol=1e-9, verbose=True, log=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + + assert hasattr(otda, "coupling_") + assert hasattr(otda, "proportions_") + assert hasattr(otda, "log_") + + # test dimensions of coupling + for i, xs in enumerate(Xs): + assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0]))) + + # test all margin constraints + mu_t = unif(nt) + + for i in range(len(Xs)): + # test margin constraints w.r.t. uniform target weights for each coupling matrix + assert_allclose( + np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3) + + # test margin constraints w.r.t. modified source weights for each source domain + + assert_allclose( + np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, + atol=1e-3) + + # test transform + transp_Xs = otda.transform(Xs=Xs) + [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] + + Xs_new, _ = make_data_classif('3gauss', ns1 + 1) + transp_Xs_new = otda.transform(Xs_new) + + # check that the oos method is working + assert_equal(transp_Xs_new.shape, Xs_new.shape) + + # check label propagation + transp_yt = otda.transform_labels(ys) + assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) + + # check inverse label propagation + transp_ys = otda.inverse_transform_labels(yt) + [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)] + [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)] + + +def test_jcpot_barycenter(): + """test_jcpot_barycenter + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + sigma = 0.1 + np.random.seed(1985) + + ps1 = .2 + ps2 = .9 + pt = .4 + + Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) + Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) + Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean', + numItermax=10000, stopThr=1e-9, verbose=False, log=False) + + np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) |