summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-07 13:36:16 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-07 13:36:16 +0200
commitd52a78d516a4cc3cedb8d36f14b686eec60d3c5b (patch)
tree0fc97f99706a800b381c4982b9fad7b7348acee6 /ot/bregman.py
parented34704eedb438821720509c5cddb745bc1b5056 (diff)
pep bregman
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py58
1 files changed, 30 insertions, 28 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 951d3ce..7f11e68 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1572,13 +1572,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
nbclasses = len(np.unique(Ys[0]))
nbdomains = len(Xs)
- # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
- all_domains = []
-
# log dictionary
if log:
- log = {'niter': 0, 'err': [], 'all_domains': []}
+ log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []}
+
+ K = []
+ M = []
+ D1 = []
+ D2 = []
+ # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
for d in range(nbdomains):
dom = {}
nsk = Xs[d].shape[0] # get number of elements for this domain
@@ -1591,28 +1594,26 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
classes = np.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- D1 = np.zeros((nbclasses, nsk))
- D2 = np.zeros((nbclasses, nsk))
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
for c in classes:
nbelemperclass = np.sum(Ys[d] == c)
if nbelemperclass != 0:
- D1[int(c), Ys[d] == c] = 1.
- D2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- dom['D1'] = D1
- dom['D2'] = D2
+ Dtmp1[int(c), Ys[d] == c] = 1.
+ Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
+ D1.append(Dtmp1)
+ D2.append(Dtmp2)
# build the cost matrix and the Gibbs kernel
- M = dist(Xs[d], Xt, metric=metric)
- M = M / np.median(M)
- dom['M'] = M
-
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
- dom['K'] = K
+ Mtmp = dist(Xs[d], Xt, metric=metric)
+ Mtmp = Mtmp / np.median(Mtmp)
+ M.append(M)
- all_domains.append(dom)
+ Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
+ np.divide(Mtmp, -reg, out=Ktmp)
+ np.exp(Ktmp, out=Ktmp)
+ K.append(Ktmp)
# uniform target distribution
a = unif(np.shape(Xt)[0])
@@ -1627,16 +1628,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
# update coupling matrices for marginal constraints w.r.t. uniform target distribution
for d in range(nbdomains):
- all_domains[d]['K'] = projC(all_domains[d]['K'], a)
- other = np.sum(all_domains[d]['K'], axis=1)
- bary = bary + np.log(np.dot(all_domains[d]['D1'], other)) / nbdomains
+ K[d] = projC(K[d], a)
+ other = np.sum(K[d], axis=1)
+ bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
bary = np.exp(bary)
# update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
for d in range(nbdomains):
- new = np.dot(all_domains[d]['D2'].T, bary)
- all_domains[d]['K'] = projR(all_domains[d]['K'], new)
+ new = np.dot(D2[d].T, bary)
+ K[d] = projR(K[d], new)
err = np.linalg.norm(bary - old_bary)
cpt = cpt + 1
@@ -1651,14 +1652,15 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
print('{:5d}|{:8e}|'.format(cpt, err))
bary = bary / np.sum(bary)
- couplings = [all_domains[d]['K'] for d in range(nbdomains)]
if log:
log['niter'] = cpt
- log['all_domains'] = all_domains
- return couplings, bary, log
+ log['M'] = M
+ log['D1'] = D1
+ log['D2'] = D2
+ return K, bary, log
else:
- return couplings, bary
+ return K, bary
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',