summaryrefslogtreecommitdiff
path: root/ot/unbalanced.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/unbalanced.py')
-rw-r--r--ot/unbalanced.py66
1 files changed, 47 insertions, 19 deletions
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 23f6607..66a8830 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -14,7 +14,7 @@ from scipy.special import logsumexp
# from .utils import unif, dist
-def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
+def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', div = "TV", numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the unbalanced entropic regularization optimal transport problem
@@ -120,20 +120,20 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
"""
if method.lower() == 'sinkhorn':
- return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
+ return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log, **kwargs)
elif method.lower() in ['sinkhorn_reg_scaling']:
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
- return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
+ return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg,
numItermax=numItermax,
stopThr=stopThr, verbose=verbose,
log=log, **kwargs)
@@ -261,8 +261,8 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
else:
raise ValueError('Unknown method %s.' % method)
-
-def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
+# TODO: update the doc
+def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, div="KL", numItermax=1000,
stopThr=1e-6, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport problem and return the loss
@@ -349,6 +349,7 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
"""
a = np.asarray(a, dtype=np.float64)
+ print(a)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)
@@ -376,24 +377,39 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
else:
u = np.ones(dim_a) / dim_a
v = np.ones(dim_b) / dim_b
+ u = np.ones(dim_a)
+ v = np.ones(dim_b)
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
+ np.true_divide(M, -reg, out=K)
np.exp(K, out=K)
-
- fi = reg_m / (reg_m + reg)
+
+ if div == "KL":
+ fi = reg_m / (reg_m + reg)
+ elif div == "TV":
+ fi = reg_m / reg
err = 1.
+
+ dx = np.ones(dim_a) / dim_a
+ dy = np.ones(dim_b) / dim_b
+ z = 1
for i in range(numItermax):
uprev = u
vprev = v
- Kv = K.dot(v)
- u = (a / Kv) ** fi
- Ktu = K.T.dot(u)
- v = (b / Ktu) ** fi
+ Kv = z*K.dot(v*dy)
+ u = scaling_iter_prox(Kv, a, fi, div)
+ #u = (a / Kv) ** fi
+ Ktu = z*K.T.dot(u*dx)
+ v = scaling_iter_prox(Ktu, b, fi, div)
+ #v = (b / Ktu) ** fi
+ #print(v*dy)
+ z = np.dot((u*dx).T, np.dot(K,v*dy))/0.35
+ print(z)
+
if (np.any(Ktu == 0.)
or np.any(np.isnan(u)) or np.any(np.isnan(v))
@@ -434,12 +450,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
if log:
return u[:, None] * K * v[None, :], log
else:
- return u[:, None] * K * v[None, :]
-
+ return z*u[:, None] * K * v[None, :]
-def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
- stopThr=1e-6, verbose=False, log=False,
- **kwargs):
+# TODO: update the doc
+def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, div = "KL", tau=1e5,
+ numItermax=1000, stopThr=1e-6,
+ verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization unbalanced optimal transport
problem and return the loss
@@ -564,7 +580,10 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- fi = reg_m / (reg_m + reg)
+ if div == "KL":
+ fi = reg_m / (reg_m + reg)
+ elif div == "TV":
+ fi = reg_m / reg
cpt = 0
err = 1.
@@ -650,6 +669,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
else:
return ot_matrix
+def scaling_iter_prox(s, p, fi, div):
+ if div == "KL":
+ return (p / s) ** fi
+ elif div == "TV":
+ return np.minimum(s*np.exp(fi), np.maximum(s*np.exp(-fi), p)) / s
+ else:
+ raise ValueError("Unknown divergence '%s'." % div)
+
+
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
numItermax=1000, stopThr=1e-6,