summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py107
1 files changed, 54 insertions, 53 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index d516dfc..e37f10c 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
-Regularized Unbalanced OT
+Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>
@@ -384,10 +384,9 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
fi = reg_m / (reg_m + reg)
- cpt = 0
err = 1.
- while (err > stopThr and cpt < numItermax):
+ for i in range(numItermax):
uprev = u
vprev = v
@@ -401,28 +400,27 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration %s' % cpt)
+ warnings.warn('Numerical errors at iteration %s' % i)
u = uprev
v = vprev
break
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
- err = 0.5 * (err_u + err_v)
- if log:
- log['err'].append(err)
+
+ err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
+ if log:
+ log['err'].append(err)
if verbose:
- if cpt % 200 == 0:
+ if i % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
- cpt += 1
+ print('{:5d}|{:8e}|'.format(i, err))
+ if err < stopThr:
+ break
if log:
- log['logu'] = np.log(u + 1e-16)
- log['logv'] = np.log(v + 1e-16)
+ log['logu'] = np.log(u + 1e-300)
+ log['logv'] = np.log(v + 1e-300)
if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
@@ -747,8 +745,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
alpha = np.zeros(dim)
beta = np.zeros(dim)
q = np.ones(dim) / dim
- while (err > stopThr and cpt < numItermax):
- qprev = q
+ for i in range(numItermax):
+ qprev = q.copy()
Kv = K.dot(v)
f_alpha = np.exp(- alpha / (reg + reg_m))
f_beta = np.exp(- beta / (reg + reg_m))
@@ -777,7 +775,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
warnings.warn('Numerical errors at iteration %s' % cpt)
q = qprev
break
- if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ if (i % 10 == 0 and not absorbing) or i == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
err = abs(q - qprev).max() / max(abs(q).max(),
@@ -785,20 +783,21 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if log:
log['err'].append(err)
if verbose:
- if cpt % 50 == 0:
+ if i % 50 == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ print('{:5d}|{:8e}|'.format(i, err))
+ if err < stopThr:
+ break
- cpt += 1
if err > stopThr:
warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
"Try a larger entropy `reg` or a lower mass `reg_m`." +
"Or a larger absorption threshold `tau`.")
if log:
- log['niter'] = cpt
- log['logu'] = np.log(u + 1e-16)
- log['logv'] = np.log(v + 1e-16)
+ log['niter'] = i
+ log['logu'] = np.log(u + 1e-300)
+ log['logv'] = np.log(v + 1e-300)
return q, log
else:
return q
@@ -882,15 +881,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
fi = reg_m / (reg_m + reg)
- v = np.ones((dim, n_hists)) / dim
- u = np.ones((dim, 1)) / dim
-
- cpt = 0
+ v = np.ones((dim, n_hists))
+ u = np.ones((dim, 1))
+ q = np.ones(dim)
err = 1.
- while (err > stopThr and cpt < numItermax):
- uprev = u
- vprev = v
+ for i in range(numItermax):
+ uprev = u.copy()
+ vprev = v.copy()
+ qprev = q.copy()
Kv = K.dot(v)
u = (A / Kv) ** fi
@@ -905,31 +904,30 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
- warnings.warn('Numerical errors at iteration %s' % cpt)
+ warnings.warn('Numerical errors at iteration %s' % i)
u = uprev
v = vprev
+ q = qprev
break
- if cpt % 10 == 0:
- # we can speed up the process by checking for the error only all
- # the 10th iterations
- err_u = abs(u - uprev).max()
- err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max()
- err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
- err = 0.5 * (err_u + err_v)
- if log:
- log['err'].append(err)
- if verbose:
- if cpt % 50 == 0:
- print(
- '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
+ # compute change in barycenter
+ err = abs(q - qprev).max()
+ err /= max(abs(q).max(), abs(qprev).max(), 1.)
+ if log:
+ log['err'].append(err)
+ # if barycenter did not change + at least 10 iterations - stop
+ if err < stopThr and i > 10:
+ break
+
+ if verbose:
+ if i % 10 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(i, err))
- cpt += 1
if log:
- log['niter'] = cpt
- log['logu'] = np.log(u + 1e-16)
- log['logv'] = np.log(v + 1e-16)
+ log['niter'] = i
+ log['logu'] = np.log(u + 1e-300)
+ log['logv'] = np.log(v + 1e-300)
return q, log
else:
return q
@@ -1002,12 +1000,14 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
if method.lower() == 'sinkhorn':
return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
+ weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
+ weights=weights,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
@@ -1015,6 +1015,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
return barycenter_unbalanced(A, M, reg, reg_m,
+ weights=weights,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)