summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
committerievred <ievgen.redko@univ-st-etienne.fr>2020-04-15 11:12:23 +0200
commit749378a50abd763c87f5cf24a4b2e0dff2a6ec6a (patch)
treee8aeb51b7bcf3b48fead20fa44eee154da5b8d05 /ot/bregman.py
parent1a4c264cc9b2cb0bb89840ee9175177e86eef3ef (diff)
fix soft labels, remove gammas from jcpot
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index c44c141..543dbaa 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1553,8 +1553,6 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
Returns
-------
- gamma : List of K (nsk x nt) ndarrays
- Optimal transportation matrices for the given parameters for each pair of source and target domains
h : (C,) ndarray
proportion estimation in the target domain
log : dict
@@ -1574,7 +1572,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
# log dictionary
if log:
- log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []}
+ log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []}
K = []
M = []
@@ -1657,9 +1655,10 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
log['M'] = M
log['D1'] = D1
log['D2'] = D2
- return K, bary, log
+ log['gamma'] = K
+ return bary, log
else:
- return K, bary
+ return bary
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',