From d52a78d516a4cc3cedb8d36f14b686eec60d3c5b Mon Sep 17 00:00:00 2001 From: ievred Date: Tue, 7 Apr 2020 13:36:16 +0200 Subject: pep bregman --- ot/bregman.py | 58 ++++++++++++++++++++++++++++++---------------------------- 1 file changed, 30 insertions(+), 28 deletions(-) (limited to 'ot/bregman.py') diff --git a/ot/bregman.py b/ot/bregman.py index 951d3ce..7f11e68 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -1572,13 +1572,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, nbclasses = len(np.unique(Ys[0])) nbdomains = len(Xs) - # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 - all_domains = [] - # log dictionary if log: - log = {'niter': 0, 'err': [], 'all_domains': []} + log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []} + + K = [] + M = [] + D1 = [] + D2 = [] + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 for d in range(nbdomains): dom = {} nsk = Xs[d].shape[0] # get number of elements for this domain @@ -1591,28 +1594,26 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, classes = np.unique(Ys[d]) # build the corresponding D_1 and D_2 matrices - D1 = np.zeros((nbclasses, nsk)) - D2 = np.zeros((nbclasses, nsk)) + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) for c in classes: nbelemperclass = np.sum(Ys[d] == c) if nbelemperclass != 0: - D1[int(c), Ys[d] == c] = 1. - D2[int(c), Ys[d] == c] = 1. / (nbelemperclass) - dom['D1'] = D1 - dom['D2'] = D2 + Dtmp1[int(c), Ys[d] == c] = 1. + Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass) + D1.append(Dtmp1) + D2.append(Dtmp2) # build the cost matrix and the Gibbs kernel - M = dist(Xs[d], Xt, metric=metric) - M = M / np.median(M) - dom['M'] = M - - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) - dom['K'] = K + Mtmp = dist(Xs[d], Xt, metric=metric) + Mtmp = Mtmp / np.median(Mtmp) + M.append(M) - all_domains.append(dom) + Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype) + np.divide(Mtmp, -reg, out=Ktmp) + np.exp(Ktmp, out=Ktmp) + K.append(Ktmp) # uniform target distribution a = unif(np.shape(Xt)[0]) @@ -1627,16 +1628,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, # update coupling matrices for marginal constraints w.r.t. uniform target distribution for d in range(nbdomains): - all_domains[d]['K'] = projC(all_domains[d]['K'], a) - other = np.sum(all_domains[d]['K'], axis=1) - bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains + K[d] = projC(K[d], a) + other = np.sum(K[d], axis=1) + bary = bary + np.log(np.dot(D1[d], other)) / nbdomains bary = np.exp(bary) # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] for d in range(nbdomains): - new = np.dot(all_domains[d]['D2'].T, bary) - all_domains[d]['K'] = projR(all_domains[d]['K'], new) + new = np.dot(D2[d].T, bary) + K[d] = projR(K[d], new) err = np.linalg.norm(bary - old_bary) cpt = cpt + 1 @@ -1651,14 +1652,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, print('{:5d}|{:8e}|'.format(cpt, err)) bary = bary / np.sum(bary) - couplings = [all_domains[d]['K'] for d in range(nbdomains)] if log: log['niter'] = cpt - log['all_domains'] = all_domains - return couplings, bary, log + log['M'] = M + log['D1'] = D1 + log['D2'] = D2 + return K, bary, log else: - return couplings, bary + return K, bary def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', -- cgit v1.2.3