summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py17
1 files changed, 8 insertions, 9 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index fc20175..c06af2f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
# geometric interpolation
delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
-
+ K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
@@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
- Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
for c in classes:
- nbelemperclass = nx.sum(Ys[d] == c)
+ nbelemperclass = float(nx.sum(Ys[d] == c))
if nbelemperclass != 0:
- Dtmp1[int(c), Ys[d] == c] = 1.
- Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- D1.append(Dtmp1)
- D2.append(Dtmp2)
+ Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1.
+ Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass)
+ D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0]))
+ D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0]))
# build the cost matrix and the Gibbs kernel
Mtmp = dist(Xs[d], Xt, metric=metric)