diff options
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r-- | ot/unbalanced.py | 66 |
1 files changed, 47 insertions, 19 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 23f6607..66a8830 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -14,7 +14,7 @@ from scipy.special import logsumexp # from .utils import unif, dist -def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, +def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the unbalanced entropic regularization optimal transport problem @@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000, """ if method.lower() == 'sinkhorn': - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, + return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp') - return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, + return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) @@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', else: raise ValueError('Unknown method %s.' % method) - -def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, +# TODO: update the doc +def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -349,6 +349,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, """ a = np.asarray(a, dtype=np.float64) + print(a) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) @@ -376,24 +377,39 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, else: u = np.ones(dim_a) / dim_a v = np.ones(dim_b) / dim_b + u = np.ones(dim_a) + v = np.ones(dim_b) # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) + np.true_divide(M, -reg, out=K) np.exp(K, out=K) - - fi = reg_m / (reg_m + reg) + + if div == "KL": + fi = reg_m / (reg_m + reg) + elif div == "TV": + fi = reg_m / reg err = 1. + + dx = np.ones(dim_a) / dim_a + dy = np.ones(dim_b) / dim_b + z = 1 for i in range(numItermax): uprev = u vprev = v - Kv = K.dot(v) - u = (a / Kv) ** fi - Ktu = K.T.dot(u) - v = (b / Ktu) ** fi + Kv = z*K.dot(v*dy) + u = scaling_iter_prox(Kv, a, fi, div) + #u = (a / Kv) ** fi + Ktu = z*K.T.dot(u*dx) + v = scaling_iter_prox(Ktu, b, fi, div) + #v = (b / Ktu) ** fi + #print(v*dy) + z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35 + print(z) + if (np.any(Ktu == 0.) or np.any(np.isnan(u)) or np.any(np.isnan(v)) @@ -434,12 +450,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000, if log: return u[:, None] * K * v[None, :], log else: - return u[:, None] * K * v[None, :] - + return z*u[:, None] * K * v[None, :] -def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000, - stopThr=1e-6, verbose=False, log=False, - **kwargs): +# TODO: update the doc +def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5, + numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): r""" Solve the entropic regularization unbalanced optimal transport problem and return the loss @@ -564,7 +580,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 np.divide(M, -reg, out=K) np.exp(K, out=K) - fi = reg_m / (reg_m + reg) + if div == "KL": + fi = reg_m / (reg_m + reg) + elif div == "TV": + fi = reg_m / reg cpt = 0 err = 1. @@ -650,6 +669,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000 else: return ot_matrix +def scaling_iter_prox(s, p, fi, div): + if div == "KL": + return (p / s) ** fi + elif div == "TV": + return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s + else: + raise ValueError("Unknown divergence '%s'." % div) + + def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, |