From ba493aa5488507937b7f9707faa17128c9aa1872 Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 31 Mar 2020 17:36:00 +0200 Subject: readme move to bregman --- ot/bregman.py | 157 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 1 deletion(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index d5e3563..d17aaf0 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -10,6 +10,7 @@ Bregman projections for regularized OT # Hicham Janati # Mokhtar Z. Alaya # Alexander Tong +# Ievgen Redko # # License: MIT License @@ -18,7 +19,6 @@ import warnings from .utils import unif, dist from scipy.optimize import fmin_l_bfgs_b - def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" @@ -1501,6 +1501,161 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, else: return np.sum(K0, axis=1) +def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, + stopThr=1e-6, verbose=False, log=False, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] + + The function solves the following optimization problem: + + .. math:: + + \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k + W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) + + s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} + + where : + + - :math:`\lambda_k` is the weight of k-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C + - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)` + + The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. + + The algorithm used for solving the problem is the Iterative Bregman projections algorithm + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. + + Parameters + ---------- + Xs : list of K np.ndarray(nsk,d) + features of all source domains' samples + Ys : list of K np.ndarray(nsk,) + labels of all source domains' samples + Xt : np.ndarray (nt,d) + samples in the target domain + reg : float + Regularization term > 0 + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on relative change in the barycenter (>0) + log : bool, optional + record log if True + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + + Returns + ------- + gamma : List of K (nsk x nt) ndarrays + Optimal transportation matrices for the given parameters for each pair of source and target domains + h : (C,) ndarray + proportion estimation in the target domain + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + + ''' + nbclasses = len(np.unique(Ys[0])) + nbdomains = len(Xs) + + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 + all_domains = [] + + # log dictionary + if log: + log = {'niter': 0, 'err': [], 'all_domains': []} + + for d in range(nbdomains): + dom = {} + nsk = Xs[d].shape[0] # get number of elements for this domain + dom['nbelem'] = nsk + classes = np.unique(Ys[d]) # get number of classes for this domain + + # format classes to start from 0 for convenience + if np.min(classes) != 0: + Ys[d] = Ys[d] - np.min(classes) + classes = np.unique(Ys[d]) + + # build the corresponding D_1 and D_2 matrices + D1 = np.zeros((nbclasses, nsk)) + D2 = np.zeros((nbclasses, nsk)) + + for c in classes: + nbelemperclass = np.sum(Ys[d] == c) + if nbelemperclass != 0: + D1[int(c), Ys[d] == c] = 1. + D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) + dom['D1'] = D1 + dom['D2'] = D2 + + # build the cost matrix and the Gibbs kernel + M = dist(Xs[d], Xt, metric=metric) + M = M / np.median(M) + + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) + dom['K'] = K + + all_domains.append(dom) + + # uniform target distribution + a = unif(np.shape(Xt)[0]) + + cpt = 0 # iterations count + err = 1 + old_bary = np.ones((nbclasses)) + + while (err > stopThr and cpt < numItermax): + + bary = np.zeros((nbclasses)) + + # update coupling matrices for marginal constraints w.r.t. uniform target distribution + for d in range(nbdomains): + all_domains[d]['K'] = projC(all_domains[d]['K'], a) + other = np.sum(all_domains[d]['K'], axis=1) + bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains + + bary = np.exp(bary) + + # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] + for d in range(nbdomains): + new = np.dot(all_domains[d]['D2'].T, bary) + all_domains[d]['K'] = projR(all_domains[d]['K'], new) + + err = np.linalg.norm(bary - old_bary) + cpt = cpt + 1 + old_bary = bary + + if log: + log['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + bary = bary / np.sum(bary) + couplings = [all_domains[d]['K'] for d in range(nbdomains)] + + if log: + log['niter'] = cpt + log['all_domains'] = all_domains + return couplings, bary, log + else: + return couplings, bary def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, -- cgit v1.2.3 From 547a03ef87e4aa92edc1e89ee2db04114e1a8ad5 Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 1 Apr 2020 09:13:58 +0200 Subject: fix test example add M to log --- examples/plot_otda_jcpot.py | 20 ++++++-------------- ot/bregman.py | 1 + test/test_da.py | 13 ++----------- 3 files changed, 9 insertions(+), 25 deletions(-) (limited to 'ot/bregman.py') diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index 5e5fff8..1641fb0 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -81,11 +81,7 @@ pl.axis('off') ############################################################################## # Instantiate Sinkhorn transport algorithm and fit them for all source domains # ---------------------------------------------------------------------------- -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-2, metric='euclidean') - -M1 = ot.dist(xs1, xt, 'euclidean') -M2 = ot.dist(xs2, xt, 'euclidean') - +ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean') def print_G(G, xs, ys, xt): for i in range(G.shape[0]): @@ -125,7 +121,7 @@ pl.axis('off') ############################################################################## # Instantiate JCPOT adaptation algorithm and fit it # ---------------------------------------------------------------------------- -otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, tol=1e-9, verbose=True, log=True) +otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) otda.fit(all_Xr, all_Yr, xt) ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2']) @@ -136,8 +132,8 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) @@ -154,10 +150,6 @@ pl.axis('off') ############################################################################## # Run oracle transport algorithm with known proportions # ---------------------------------------------------------------------------- - -otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True) -otda.fit(all_Xr, all_Yr, xt) - h_res = np.array([1 - pt, pt]) ws1 = h_res.dot(otda.log_['all_domains'][0]['D2']) @@ -168,8 +160,8 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], M1, reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], M2, reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) diff --git a/ot/bregman.py b/ot/bregman.py index d17aaf0..fb959e9 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1603,6 +1603,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # build the cost matrix and the Gibbs kernel M = dist(Xs[d], Xt, metric=metric) M = M / np.median(M) + dom['M'] = M K = np.empty(M.shape, dtype=M.dtype) np.divide(M, -reg, out=K) diff --git a/test/test_da.py b/test/test_da.py index 7526f30..a13550c 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -568,7 +568,7 @@ def test_jcpot_transport_class(): Xs = [Xs1, Xs2] ys = [ys1, ys2] - otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True) + otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log = True) # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) @@ -591,14 +591,8 @@ def test_jcpot_transport_class(): # test margin constraints w.r.t. modified source weights for each source domain - D1 = np.zeros((len(np.unique(ys[i])), len(ys[i]))) - for c in np.unique(ys[i]): - nbelemperclass = np.sum(ys[i] == c) - if nbelemperclass != 0: - D1[int(c), ys[i] == c] = 1. - assert_allclose( - np.dot(D1, np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) + np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) # test transform transp_Xs = otda.transform(Xs=Xs) @@ -609,6 +603,3 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) - - -test_jcpot_transport_class() \ No newline at end of file -- cgit v1.2.3 From 9200af5d795517b0772c10bb3d16022dd1a12791 Mon Sep 17 00:00:00 2001 From: ievred Date: Thu, 2 Apr 2020 15:29:12 +0200 Subject: laplace v1 --- ot/bregman.py | 72 +++++++++++++++++++++++++++++++++---------------------- ot/datasets.py | 4 ++-- ot/lp/__init__.py | 4 +--- ot/plot.py | 3 ++- 4 files changed, 49 insertions(+), 34 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index fb959e9..951d3ce 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -19,6 +19,7 @@ import warnings from .utils import unif, dist from scipy.optimize import fmin_l_bfgs_b + def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): r""" @@ -539,12 +540,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, old_v = v[i_2] v[i_2] = b[i_2] / (K[:, i_2].T.dot(u)) G[:, i_2] = u * K[:, i_2] * v[i_2] - #aviol = (G@one_m - a) - #aviol_2 = (G.T@one_n - b) + # aviol = (G@one_m - a) + # aviol_2 = (G.T@one_n - b) viol += (-old_v + v[i_2]) * K[:, i_2] * u viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2] - #print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) + # print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) if stopThr_val <= stopThr: break @@ -715,7 +716,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if np.abs(u).max() > tau or np.abs(v).max() > tau: if n_hists: alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if n_hists: @@ -940,7 +941,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, # the 10th iterations transp = G err = np.linalg.norm( - (np.sum(transp, axis=0) - b))**2 + np.linalg.norm((np.sum(transp, axis=1) - a))**2 + (np.sum(transp, axis=0) - b)) ** 2 + np.linalg.norm((np.sum(transp, axis=1) - a)) ** 2 if log: log['err'].append(err) @@ -966,7 +967,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, def geometricBar(weights, alldistribT): """return the weighted geometric mean of distributions""" - assert(len(weights) == alldistribT.shape[1]) + assert (len(weights) == alldistribT.shape[1]) return np.exp(np.dot(np.log(alldistribT), weights.T)) @@ -1108,7 +1109,7 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, if weights is None: weights = np.ones(A.shape[1]) / A.shape[1] else: - assert(len(weights) == A.shape[1]) + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -1206,7 +1207,7 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, if weights is None: weights = np.ones(n_hists) / n_hists else: - assert(len(weights) == A.shape[1]) + assert (len(weights) == A.shape[1]) if log: log = {'err': []} @@ -1334,7 +1335,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, if weights is None: weights = np.ones(A.shape[0]) / A.shape[0] else: - assert(len(weights) == A.shape[0]) + assert (len(weights) == A.shape[0]) if log: log = {'err': []} @@ -1350,11 +1351,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, # this is equivalent to blurring on horizontal then vertical directions t = np.linspace(0, 1, A.shape[1]) [Y, X] = np.meshgrid(t, t) - xi1 = np.exp(-(X - Y)**2 / reg) + xi1 = np.exp(-(X - Y) ** 2 / reg) t = np.linspace(0, 1, A.shape[2]) [Y, X] = np.meshgrid(t, t) - xi2 = np.exp(-(X - Y)**2 / reg) + xi2 = np.exp(-(X - Y) ** 2 / reg) def K(x): return np.dot(np.dot(xi1, x), xi2) @@ -1501,6 +1502,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, else: return np.sum(K0, axis=1) + def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, stopThr=1e-6, verbose=False, log=False, **kwargs): r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27] @@ -1658,6 +1660,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, else: return couplings, bary + def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): @@ -1749,7 +1752,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', return pi -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, **kwargs): r''' Solve the entropic regularization optimal transport problem from empirical data and return the OT loss @@ -1831,14 +1835,17 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num M = dist(X_s, X_t, metric=metric) if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) return sinkhorn_loss, log else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, + **kwargs) return sinkhorn_loss -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, **kwargs): r''' Compute the sinkhorn divergence loss from empirical data @@ -1924,11 +1931,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 ''' if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, + stopThr=1e-9, verbose=verbose, log=log, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) @@ -1943,11 +1953,14 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli return max(0, sinkhorn_div), log else: - sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, **kwargs) - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, **kwargs) - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, verbose=verbose, log=log, **kwargs) + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, a, b, metric=metric, numIterMax=numIterMax, stopThr=1e-9, + verbose=verbose, log=log, **kwargs) sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b) return max(0, sinkhorn_div) @@ -2039,7 +2052,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res try: import bottleneck except ImportError: - warnings.warn("Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") + warnings.warn( + "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance.") bottleneck = np a = np.asarray(a, dtype=np.float64) @@ -2173,10 +2187,11 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget - bounds_v = [(max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + bounds_v = [( + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) @@ -2225,7 +2240,8 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res return usc, vsc def screened_obj(usc, vsc): - part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, np.log(vsc)) + part_IJ = np.dot(np.dot(usc, K_IJ), vsc) - kappa * np.dot(a_I, np.log(usc)) - (1. / kappa) * np.dot(b_J, + np.log(vsc)) part_IJc = np.dot(usc, vec_eps_IJc) part_IcJ = np.dot(vec_eps_IcJ, vsc) psi_epsilon = part_IJ + part_IJc + part_IcJ @@ -2247,9 +2263,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res g = np.hstack([g_u, g_v]) return f, g - #----------------------------------------------------------------------------------------------------------------# + # ----------------------------------------------------------------------------------------------------------------# # Step 2: L-BFGS-B solver # - #----------------------------------------------------------------------------------------------------------------# + # ----------------------------------------------------------------------------------------------------------------# u0, v0 = restricted_sinkhorn(u0, v0) theta0 = np.hstack([u0, v0]) diff --git a/ot/datasets.py b/ot/datasets.py index eea9f37..a1ca7b6 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -30,7 +30,7 @@ def make_1D_gauss(n, m, s): 1D histogram for a gaussian distribution """ x = np.arange(n, dtype=np.float64) - h = np.exp(-(x - m)**2 / (2 * s**2)) + h = np.exp(-(x - m) ** 2 / (2 * s ** 2)) return h / h.sum() @@ -80,7 +80,7 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None): return make_2D_samples_gauss(n, m, sigma, random_state=None) -def make_data_classif(dataset, n, nz=.5, theta=0, p = .5, random_state=None, **kwargs): +def make_data_classif(dataset, n, nz=.5, theta=0, p=.5, random_state=None, **kwargs): """Dataset generation for classification problems Parameters diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index cdd505d..7eaa44a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -2,8 +2,6 @@ """ Solvers for the original linear program OT problem - - """ # Author: Remi Flamary @@ -18,7 +16,7 @@ from scipy.sparse import coo_matrix from .import cvx # import compiled emd -from .emd_wrap import emd_c, check_result, emd_1d_sorted +#from .emd_wrap import emd_c, check_result, emd_1d_sorted from ..utils import parmap from .cvx import barycenter from ..utils import dist diff --git a/ot/plot.py b/ot/plot.py index f403e98..ad436b4 100644 --- a/ot/plot.py +++ b/ot/plot.py @@ -78,9 +78,10 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): thr : float, optional threshold above which the line is drawn **kwargs : dict - paameters given to the plot functions (default color is black if + parameters given to the plot functions (default color is black if nothing given) """ + if ('color' not in kwargs) and ('c' not in kwargs): kwargs['color'] = 'k' mx = G.max() -- cgit v1.2.3 From d52a78d516a4cc3cedb8d36f14b686eec60d3c5b Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 7 Apr 2020 13:36:16 +0200 Subject: pep bregman --- ot/bregman.py | 58 ++++++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 28 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 951d3ce..7f11e68 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1572,13 +1572,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, nbclasses = len(np.unique(Ys[0])) nbdomains = len(Xs) - # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 - all_domains = [] - # log dictionary if log: - log = {'niter': 0, 'err': [], 'all_domains': []} + log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []} + + K = [] + M = [] + D1 = [] + D2 = [] + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 for d in range(nbdomains): dom = {} nsk = Xs[d].shape[0] # get number of elements for this domain @@ -1591,28 +1594,26 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, classes = np.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - D1 = np.zeros((nbclasses, nsk)) - D2 = np.zeros((nbclasses, nsk)) + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) for c in classes: nbelemperclass = np.sum(Ys[d] == c) if nbelemperclass != 0: - D1[int(c), Ys[d] == c] = 1. - D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) - dom['D1'] = D1 - dom['D2'] = D2 + Dtmp1[int(c), Ys[d] == c] = 1. + Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) + D1.append(Dtmp1) + D2.append(Dtmp2) # build the cost matrix and the Gibbs kernel - M = dist(Xs[d], Xt, metric=metric) - M = M / np.median(M) - dom['M'] = M - - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - dom['K'] = K + Mtmp = dist(Xs[d], Xt, metric=metric) + Mtmp = Mtmp / np.median(Mtmp) + M.append(M) - all_domains.append(dom) + Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) + np.divide(Mtmp, -reg, out=Ktmp) + np.exp(Ktmp, out=Ktmp) + K.append(Ktmp) # uniform target distribution a = unif(np.shape(Xt)[0]) @@ -1627,16 +1628,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): - all_domains[d]['K'] = projC(all_domains[d]['K'], a) - other = np.sum(all_domains[d]['K'], axis=1) - bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains + K[d] = projC(K[d], a) + other = np.sum(K[d], axis=1) + bary = bary + np.log(np.dot(D1[d], other)) / nbdomains bary = np.exp(bary) # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): - new = np.dot(all_domains[d]['D2'].T, bary) - all_domains[d]['K'] = projR(all_domains[d]['K'], new) + new = np.dot(D2[d].T, bary) + K[d] = projR(K[d], new) err = np.linalg.norm(bary - old_bary) cpt = cpt + 1 @@ -1651,14 +1652,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5d}|{:8e}|'.format(cpt, err)) bary = bary / np.sum(bary) - couplings = [all_domains[d]['K'] for d in range(nbdomains)] if log: log['niter'] = cpt - log['all_domains'] = all_domains - return couplings, bary, log + log['M'] = M + log['D1'] = D1 + log['D2'] = D2 + return K, bary, log else: - return couplings, bary + return K, bary def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', -- cgit v1.2.3 From 34e13d467e376e9bfee2eb15771d9308518c2adb Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 7 Apr 2020 13:44:23 +0200 Subject: pep bregman --- ot/bregman.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 7f11e68..ec81924 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -716,7 +716,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, if np.abs(u).max() > tau or np.abs(v).max() > tau: if n_hists: alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if n_hists: @@ -2189,7 +2189,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, res # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget bounds_v = [( max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), -- cgit v1.2.3 From 2c9f992157844d6253a302905417e86580ac6b12 Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 7 Apr 2020 13:50:11 +0200 Subject: upd --- examples/plot_otda_classes.py | 1 - examples/plot_otda_jcpot.py | 16 ++++++++-------- ot/bregman.py | 2 +- test/test_da.py | 2 +- 4 files changed, 10 insertions(+), 11 deletions(-) (limited to 'ot/bregman.py') diff --git a/examples/plot_otda_classes.py b/examples/plot_otda_classes.py index c311fbd..f028022 100644 --- a/examples/plot_otda_classes.py +++ b/examples/plot_otda_classes.py @@ -17,7 +17,6 @@ approaches currently supported in POT. import matplotlib.pylab as pl import ot - ############################################################################## # Generate data # ------------- diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index ce6b88f..316fa8b 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -118,16 +118,16 @@ pl.axis('off') otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) otda.fit(all_Xr, all_Yr, xt) -ws1 = otda.proportions_.dot(otda.log_['all_domains'][0]['D2']) -ws2 = otda.proportions_.dot(otda.log_['all_domains'][1]['D2']) +ws1 = otda.proportions_.dot(otda.log_['D2'][0]) +ws2 = otda.proportions_.dot(otda.log_['D2'][1]) pl.figure(3) pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) @@ -146,16 +146,16 @@ pl.axis('off') # ---------------------------------------------------------------------------- h_res = np.array([1 - pt, pt]) -ws1 = h_res.dot(otda.log_['all_domains'][0]['D2']) -ws2 = h_res.dot(otda.log_['all_domains'][1]['D2']) +ws1 = h_res.dot(otda.log_['D2'][0]) +ws2 = h_res.dot(otda.log_['D2'][1]) pl.figure(4) pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['all_domains'][0]['M'], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['all_domains'][1]['M'], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) diff --git a/ot/bregman.py b/ot/bregman.py index ec81924..61dfa52 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1608,7 +1608,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # build the cost matrix and the Gibbs kernel Mtmp = dist(Xs[d], Xt, metric=metric) Mtmp = Mtmp / np.median(Mtmp) - M.append(M) + M.append(Mtmp) Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) np.divide(Mtmp, -reg, out=Ktmp) diff --git a/test/test_da.py b/test/test_da.py index 372ebd4..4eaf193 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -589,7 +589,7 @@ def test_jcpot_transport_class(): # test margin constraints w.r.t. modified source weights for each source domain assert_allclose( - np.dot(otda.log_['all_domains'][i]['D1'], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, + np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3, atol=1e-3) # test transform -- cgit v1.2.3 From c68b52d1623683e86555484bf9a4875a66957bb6 Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 8 Apr 2020 10:08:47 +0200 Subject: remove laplace from jcpot --- examples/plot_otda_jcpot.py | 10 +- examples/plot_otda_laplacian.py | 127 ----------------------- ot/bregman.py | 1 - ot/da.py | 216 ---------------------------------------- test/test_da.py | 54 ---------- 5 files changed, 5 insertions(+), 403 deletions(-) delete mode 100644 examples/plot_otda_laplacian.py (limited to 'ot/bregman.py') diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py index 316fa8b..c495690 100644 --- a/examples/plot_otda_jcpot.py +++ b/examples/plot_otda_jcpot.py @@ -115,7 +115,7 @@ pl.axis('off') ############################################################################## # Instantiate JCPOT adaptation algorithm and fit it # ---------------------------------------------------------------------------- -otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) +otda = ot.da.JCPOTTransport(reg_e=1, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True) otda.fit(all_Xr, all_Yr, xt) ws1 = otda.proportions_.dot(otda.log_['D2'][0]) @@ -126,8 +126,8 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) @@ -154,8 +154,8 @@ pl.clf() plot_ax(dec1, 'Source 1') plot_ax(dec2, 'Source 2') plot_ax(dect, 'Target') -print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt) -print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt) +print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-1), xs1, ys1, xt) +print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-1), xs2, ys2, xt) pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9) pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9) pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9) diff --git a/examples/plot_otda_laplacian.py b/examples/plot_otda_laplacian.py deleted file mode 100644 index 965380c..0000000 --- a/examples/plot_otda_laplacian.py +++ /dev/null @@ -1,127 +0,0 @@ -# -*- coding: utf-8 -*- -""" -======================== -OT for domain adaptation -======================== - -This example introduces a domain adaptation in a 2D setting and OTDA -approache with Laplacian regularization. - -""" - -# Authors: Ievgen Redko - -# License: MIT License - -import matplotlib.pylab as pl -import ot - -############################################################################## -# Generate data -# ------------- - -n_source_samples = 150 -n_target_samples = 150 - -Xs, ys = ot.datasets.make_data_classif('3gauss', n_source_samples) -Xt, yt = ot.datasets.make_data_classif('3gauss2', n_target_samples) - - -############################################################################## -# Instantiate the different transport algorithms and fit them -# ----------------------------------------------------------- - -# EMD Transport -ot_emd = ot.da.EMDTransport() -ot_emd.fit(Xs=Xs, Xt=Xt) - -# Sinkhorn Transport -ot_sinkhorn = ot.da.SinkhornTransport(reg_e=.01) -ot_sinkhorn.fit(Xs=Xs, Xt=Xt) - -# EMD Transport with Laplacian regularization -ot_emd_laplace = ot.da.EMDLaplaceTransport(reg_lap=100, reg_src=1) -ot_emd_laplace.fit(Xs=Xs, Xt=Xt) - -# transport source samples onto target samples -transp_Xs_emd = ot_emd.transform(Xs=Xs) -transp_Xs_sinkhorn = ot_sinkhorn.transform(Xs=Xs) -transp_Xs_emd_laplace = ot_emd_laplace.transform(Xs=Xs) - -############################################################################## -# Fig 1 : plots source and target samples -# --------------------------------------- - -pl.figure(1, figsize=(10, 5)) -pl.subplot(1, 2, 1) -pl.scatter(Xs[:, 0], Xs[:, 1], c=ys, marker='+', label='Source samples') -pl.xticks([]) -pl.yticks([]) -pl.legend(loc=0) -pl.title('Source samples') - -pl.subplot(1, 2, 2) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', label='Target samples') -pl.xticks([]) -pl.yticks([]) -pl.legend(loc=0) -pl.title('Target samples') -pl.tight_layout() - - -############################################################################## -# Fig 2 : plot optimal couplings and transported samples -# ------------------------------------------------------ - -param_img = {'interpolation': 'nearest'} - -pl.figure(2, figsize=(15, 8)) -pl.subplot(2, 3, 1) -pl.imshow(ot_emd.coupling_, **param_img) -pl.xticks([]) -pl.yticks([]) -pl.title('Optimal coupling\nEMDTransport') - -pl.figure(2, figsize=(15, 8)) -pl.subplot(2, 3, 2) -pl.imshow(ot_sinkhorn.coupling_, **param_img) -pl.xticks([]) -pl.yticks([]) -pl.title('Optimal coupling\nSinkhornTransport') - -pl.subplot(2, 3, 3) -pl.imshow(ot_emd_laplace.coupling_, **param_img) -pl.xticks([]) -pl.yticks([]) -pl.title('Optimal coupling\nEMDLaplaceTransport') - -pl.subplot(2, 3, 4) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_emd[:, 0], transp_Xs_emd[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.xticks([]) -pl.yticks([]) -pl.title('Transported samples\nEmdTransport') -pl.legend(loc="lower left") - -pl.subplot(2, 3, 5) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_sinkhorn[:, 0], transp_Xs_sinkhorn[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.xticks([]) -pl.yticks([]) -pl.title('Transported samples\nSinkhornTransport') - -pl.subplot(2, 3, 6) -pl.scatter(Xt[:, 0], Xt[:, 1], c=yt, marker='o', - label='Target samples', alpha=0.3) -pl.scatter(transp_Xs_emd_laplace[:, 0], transp_Xs_emd_laplace[:, 1], c=ys, - marker='+', label='Transp samples', s=30) -pl.xticks([]) -pl.yticks([]) -pl.title('Transported samples\nEMDLaplaceTransport') -pl.tight_layout() - -pl.show() diff --git a/ot/bregman.py b/ot/bregman.py index 61dfa52..410ae85 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1607,7 +1607,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # build the cost matrix and the Gibbs kernel Mtmp = dist(Xs[d], Xt, metric=metric) - Mtmp = Mtmp / np.median(Mtmp) M.append(Mtmp) Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) diff --git a/ot/da.py b/ot/da.py index 0fdd3be..90e9e92 100644 --- a/ot/da.py +++ b/ot/da.py @@ -748,115 +748,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, return A, b -def emd_laplace(a, b, xs, xt, M, sim, eta, alpha, - numItermax, stopThr, numInnerItermax, - stopInnerThr, log=False, verbose=False, **kwargs): - r"""Solve the optimal transport problem (OT) with Laplacian regularization - - .. math:: - \gamma = arg\min_\gamma <\gamma,M>_F + eta\Omega_\alpha(\gamma) - - s.t.\ \gamma 1 = a - - \gamma^T 1= b - - \gamma\geq 0 - - where: - - - a and b are source and target weights (sum to 1) - - xs and xt are source and target samples - - M is the (ns,nt) metric cost matrix - - :math:`\Omega_\alpha` is the Laplacian regularization term - :math:`\Omega_\alpha = (1-\alpha)/n_s^2\sum_{i,j}S^s_{i,j}\|T(\mathbf{x}^s_i)-T(\mathbf{x}^s_j)\|^2+\alpha/n_t^2\sum_{i,j}S^t_{i,j}^'\|T(\mathbf{x}^t_i)-T(\mathbf{x}^t_j)\|^2` - with :math:`S^s_{i,j}, S^t_{i,j}` denoting source and target similarity matrices and :math:`T(\cdot)` being a barycentric mapping - - The algorithm used for solving the problem is the conditional gradient algorithm as proposed in [5]. - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) - samples weights in the target domain - xs : np.ndarray (ns,d) - samples in the source domain - xt : np.ndarray (nt,d) - samples in the target domain - M : np.ndarray (ns,nt) - loss matrix - eta : float - Regularization term for Laplacian regularization - alpha : float - Regularization term for source domain's importance in regularization - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (inner emd solver) (>0) - numInnerItermax : int, optional - Max number of iterations (inner CG solver) - stopInnerThr : float, optional - Stop threshold on error (inner CG solver) (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns x nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - References - ---------- - - .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE - Transactions on Pattern Analysis and Machine Intelligence , - vol.PP, no.99, pp.1-1 - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - if sim == 'gauss': - if 'rbfparam' not in kwargs: - kwargs['rbfparam'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2)) - sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['rbfparam']) - sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['rbfparam']) - - elif sim == 'knn': - if 'nn' not in kwargs: - kwargs['nn'] = 5 - - from sklearn.neighbors import kneighbors_graph - - sS = kneighbors_graph(xs, kwargs['nn']).toarray() - sS = (sS + sS.T) / 2 - sT = kneighbors_graph(xt, kwargs['nn']).toarray() - sT = (sT + sT.T) / 2 - - lS = laplacian(sS) - lT = laplacian(sT) - - def f(G): - return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \ - + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs))))) - - def df(G): - return alpha * np.dot(lS + lS.T, np.dot(G, np.dot(xt, xt.T)))\ - + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lT + lT.T))) - - return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax, - stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log) - - def distribution_estimation_uniform(X): """estimates a uniform distribution from an array of samples X @@ -1603,113 +1494,6 @@ class SinkhornLpl1Transport(BaseTransport): return self -class EMDLaplaceTransport(BaseTransport): - - """Domain Adapatation OT method based on Earth Mover's Distance with Laplacian regularization - - Parameters - ---------- - reg_lap : float, optional (default=1) - Laplacian regularization parameter - reg_src : float, optional (default=0.5) - Source relative importance in regularization - metric : string, optional (default="sqeuclidean") - The ground metric for the Wasserstein problem - norm : string, optional (default=None) - If given, normalize the ground metric to avoid numerical errors that - can occur with large metric values. - similarity : string, optional (default="knn") - The similarity to use either knn or gaussian - max_iter : int, optional (default=100) - Max number of BCD iterations - tol : float, optional (default=1e-5) - Stop threshold on relative loss decrease (>0) - max_inner_iter : int, optional (default=10) - Max number of iterations (inner CG solver) - inner_tol : float, optional (default=1e-6) - Stop threshold on error (inner CG solver) (>0) - log : int, optional (default=False) - Controls the logs of the optimization algorithm - distribution_estimation : callable, optional (defaults to the uniform) - The kind of distribution estimation to employ - out_of_sample_map : string, optional (default="ferradans") - The kind of out of sample mapping to apply to transport samples - from a domain into another one. Currently the only possible option is - "ferradans" which uses the method proposed in [6]. - - Attributes - ---------- - coupling_ : array-like, shape (n_source_samples, n_target_samples) - The optimal coupling - - References - ---------- - .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, - "Optimal Transport for Domain Adaptation," in IEEE Transactions - on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1 - """ - - def __init__(self, reg_lap=1., reg_src=1., alpha=0.5, - metric="sqeuclidean", norm=None, similarity="knn", max_iter=100, tol=1e-9, - max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False, - distribution_estimation=distribution_estimation_uniform, - out_of_sample_map='ferradans'): - self.reg_lap = reg_lap - self.reg_src = reg_src - self.alpha = alpha - self.metric = metric - self.norm = norm - self.similarity = similarity - self.max_iter = max_iter - self.tol = tol - self.max_inner_iter = max_inner_iter - self.inner_tol = inner_tol - self.log = log - self.verbose = verbose - self.distribution_estimation = distribution_estimation - self.out_of_sample_map = out_of_sample_map - - def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples - (Xs, ys) and (Xt, yt) - - Parameters - ---------- - Xs : array-like, shape (n_source_samples, n_features) - The training input samples. - ys : array-like, shape (n_source_samples,) - The class labels - Xt : array-like, shape (n_target_samples, n_features) - The training input samples. - yt : array-like, shape (n_target_samples,) - The class labels. If some target samples are unlabeled, fill the - yt's elements with -1. - - Warning: Note that, due to this convention -1 cannot be used as a - class label - - Returns - ------- - self : object - Returns self. - """ - - super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt) - - returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_, - xt=self.xt_, M=self.cost_, sim=self.similarity, eta=self.reg_lap, alpha=self.reg_src, - numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter, - stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose) - - # coupling estimation - if self.log: - self.coupling_, self.log_ = returned_ - else: - self.coupling_ = returned_ - self.log_ = dict() - return self - - class SinkhornL1l2Transport(BaseTransport): """Domain Adapatation OT method based on sinkhorn algorithm + diff --git a/test/test_da.py b/test/test_da.py index 4eaf193..1517cec 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -601,57 +601,3 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) - - -def test_emd_laplace_class(): - """test_emd_laplace_transport - """ - ns = 150 - nt = 200 - - Xs, ys = make_data_classif('3gauss', ns) - Xt, yt = make_data_classif('3gauss2', nt) - - otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True) - - # test its computed - otda.fit(Xs=Xs, ys=ys, Xt=Xt) - - assert hasattr(otda, "coupling_") - assert hasattr(otda, "log_") - - # test dimensions of coupling - assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0]))) - - # test all margin constraints - mu_s = unif(ns) - mu_t = unif(nt) - - assert_allclose( - np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3) - assert_allclose( - np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3) - - # test transform - transp_Xs = otda.transform(Xs=Xs) - [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)] - - Xs_new, _ = make_data_classif('3gauss', ns + 1) - transp_Xs_new = otda.transform(Xs_new) - - # check that the oos method is working - assert_equal(transp_Xs_new.shape, Xs_new.shape) - - # test inverse transform - transp_Xt = otda.inverse_transform(Xt=Xt) - assert_equal(transp_Xt.shape, Xt.shape) - - Xt_new, _ = make_data_classif('3gauss2', nt + 1) - transp_Xt_new = otda.inverse_transform(Xt=Xt_new) - - # check that the oos method is working - assert_equal(transp_Xt_new.shape, Xt_new.shape) - - # test fit_transform - transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt) - assert_equal(transp_Xs.shape, Xs.shape) -- cgit v1.2.3 From bc51793333994a1bf6263c9e9c111d754172fc82 Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 8 Apr 2020 14:35:00 +0200 Subject: added test barycenter + modif target --- ot/bregman.py | 2 +- test/test_da.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 410ae85..c44c141 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1528,7 +1528,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain. The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution. + with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform target distribution. Parameters ---------- diff --git a/test/test_da.py b/test/test_da.py index b58cf51..c54dab7 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -601,3 +601,31 @@ def test_jcpot_transport_class(): # check that the oos method is working assert_equal(transp_Xs_new.shape, Xs_new.shape) + + +def test_jcpot_barycenter(): + """test_jcpot_barycenter + """ + + ns1 = 150 + ns2 = 150 + nt = 200 + + sigma = 0.1 + np.random.seed(1985) + + ps1 = .2 + ps2 = .9 + pt = .4 + + Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1) + Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2) + Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt) + + Xs = [Xs1, Xs2] + ys = [ys1, ys2] + + _, prop, = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean', + numItermax=10000, stopThr=1e-9, verbose=False, log=False) + + np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3) -- cgit v1.2.3 From 749378a50abd763c87f5cf24a4b2e0dff2a6ec6a Mon Sep 17 00:00:00 2001 From: ievred Date: Wed, 15 Apr 2020 11:12:23 +0200 Subject: fix soft labels, remove gammas from jcpot --- ot/bregman.py | 9 ++++----- ot/da.py | 40 +++++++++++++++++++++------------------- test/test_da.py | 14 +++++++++++++- 3 files changed, 38 insertions(+), 25 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index c44c141..543dbaa 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1553,8 +1553,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, Returns ------- - gamma : List of K (nsk x nt) ndarrays - Optimal transportation matrices for the given parameters for each pair of source and target domains h : (C,) ndarray proportion estimation in the target domain log : dict @@ -1574,7 +1572,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # log dictionary if log: - log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []} + log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []} K = [] M = [] @@ -1657,9 +1655,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, log['M'] = M log['D1'] = D1 log['D2'] = D2 - return K, bary, log + log['gamma'] = K + return bary, log else: - return K, bary + return bary def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', diff --git a/ot/da.py b/ot/da.py index 4318c0d..30e5a61 100644 --- a/ot/da.py +++ b/ot/da.py @@ -956,8 +956,8 @@ class BaseTransport(BaseEstimator): Returns ------- - transp_ys : array-like, shape (n_target_samples,) - Estimated target labels. + transp_ys : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. References ---------- @@ -985,10 +985,10 @@ class BaseTransport(BaseEstimator): for c in classes: D1[int(c), ysTemp == c] = 1 - # compute transported samples + # compute propagated labels transp_ys = np.dot(D1, transp) - return np.argmax(transp_ys, axis=0) + return transp_ys.T def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): @@ -1066,8 +1066,8 @@ class BaseTransport(BaseEstimator): Returns ------- - transp_ys : array-like, shape (n_source_samples,) - Estimated source labels. + transp_ys : array-like, shape (n_source_samples, nb_classes) + Estimated soft source labels. """ # check the necessary inputs parameters are here @@ -1087,10 +1087,10 @@ class BaseTransport(BaseEstimator): for c in classes: D1[int(c), ytTemp == c] = 1 - # compute transported samples + # compute propagated samples transp_ys = np.dot(D1, transp.T) - return np.argmax(transp_ys, axis=0) + return transp_ys.T class LinearTransport(BaseTransport): @@ -2083,13 +2083,15 @@ class JCPOTTransport(BaseTransport): returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e, metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol, - verbose=self.verbose, log=self.log) + verbose=self.verbose, log=True) + + self.coupling_ = returned_[1]['gamma'] # deal with the value of log if self.log: - self.coupling_, self.proportions_, self.log_ = returned_ + self.proportions_, self.log_ = returned_ else: - self.coupling_, self.proportions_ = returned_ + self.proportions_ = returned_ self.log_ = dict() return self @@ -2176,8 +2178,8 @@ class JCPOTTransport(BaseTransport): Returns ------- - yt : array-like, shape (n_target_samples,) - Estimated target labels. + yt : array-like, shape (n_target_samples, nb_classes) + Estimated soft target labels. """ # check the necessary inputs parameters are here @@ -2203,10 +2205,10 @@ class JCPOTTransport(BaseTransport): for c in classes: D1[int(c), ysTemp == c] = 1 - # compute transported samples + # compute propagated labels yt = yt + np.dot(D1, transp) / len(ys) - return np.argmax(yt, axis=0) + return yt.T def inverse_transform_labels(self, yt=None): """Propagate source labels ys to obtain target labels @@ -2218,8 +2220,8 @@ class JCPOTTransport(BaseTransport): Returns ------- - transp_ys : list of K array-like objects, shape K x (nk_source_samples,) - A list of estimated source labels + transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes) + A list of estimated soft source labels """ # check the necessary inputs parameters are here @@ -2241,7 +2243,7 @@ class JCPOTTransport(BaseTransport): # set nans to 0 transp[~ np.isfinite(transp)] = 0 - # compute transported labels - transp_ys.append(np.argmax(np.dot(D1, transp.T), axis=0)) + # compute propagated labels + transp_ys.append(np.dot(D1, transp.T).T) return transp_ys diff --git a/test/test_da.py b/test/test_da.py index d96046d..70296bf 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -68,10 +68,12 @@ def test_sinkhorn_lpl1_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) # test unsupervised vs semi-supervised mode otda_unsup = ot.da.SinkhornLpl1Transport() @@ -140,10 +142,12 @@ def test_sinkhorn_l1l2_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -229,10 +233,12 @@ def test_sinkhorn_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -298,10 +304,12 @@ def test_unbalanced_sinkhorn_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xs_new, _ = make_data_classif('3gauss', ns + 1) transp_Xs_new = otda.transform(Xs_new) @@ -388,10 +396,12 @@ def test_emd_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) assert_equal(transp_ys.shape[0], ys.shape[0]) + assert_equal(transp_ys.shape[1], len(np.unique(yt))) Xt_new, _ = make_data_classif('3gauss2', nt + 1) transp_Xt_new = otda.inverse_transform(Xt=Xt_new) @@ -645,10 +655,12 @@ def test_jcpot_transport_class(): # check label propagation transp_yt = otda.transform_labels(ys) assert_equal(transp_yt.shape[0], yt.shape[0]) + assert_equal(transp_yt.shape[1], len(np.unique(ys))) # check inverse label propagation transp_ys = otda.inverse_transform_labels(yt) - [assert_equal(x.shape, y.shape) for x, y in zip(transp_ys, ys)] + [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)] + [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)] def test_jcpot_barycenter(): -- cgit v1.2.3