summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py530
1 files changed, 356 insertions, 174 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 70e4208..2f27d58 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -7,10 +7,12 @@ Bregman projections for regularized OT
# Nicolas Courty <ncourty@irisa.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
+# Hicham Janati <hicham.janati@inria.fr>
#
# License: MIT License
import numpy as np
+import warnings
from .utils import unif, dist
@@ -31,7 +33,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
@@ -40,12 +42,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, n_b)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +66,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, n_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -103,30 +105,23 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
"""
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return _sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
elif method.lower() == 'greenkhorn':
- def sink():
- return greenkhorn(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log)
+ return _greenkhorn(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return _sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return _sinkhorn_epsilon_scaling(a, b, M, reg,
+ numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
-
- return sink()
+ raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -146,7 +141,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
@@ -155,12 +150,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, n_b)
loss matrix
reg : float
Regularization term >0
@@ -218,35 +213,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
"""
-
+ b = np.asarray(b, dtype=np.float64)
+ if len(b.shape) < 2:
+ b = b[:, None]
if method.lower() == 'sinkhorn':
- def sink():
- return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return _sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'sinkhorn_stabilized':
- def sink():
- return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return _sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'sinkhorn_epsilon_scaling':
- def sink():
- return sinkhorn_epsilon_scaling(
- a, b, M, reg, numItermax=numItermax,
- stopThr=stopThr, verbose=verbose, log=log, **kwargs)
+ return _sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
else:
- print('Warning : unknown method using classic Sinkhorn Knopp')
-
- def sink():
- return sinkhorn_knopp(a, b, M, reg, **kwargs)
+ raise ValueError("Unknown method '%s'." % method)
- b = np.asarray(b, dtype=np.float64)
- if len(b.shape) < 2:
- b = b[:, None]
- return sink()
-
-
-def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
- stopThr=1e-9, verbose=False, log=False, **kwargs):
+def _sinkhorn_knopp(a, b, M, reg, numItermax=1000,
+ stopThr=1e-9, verbose=False, log=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -262,7 +247,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
@@ -271,12 +256,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, n_b)
loss matrix
reg : float
Regularization term >0
@@ -291,7 +276,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, n_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -331,25 +316,25 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- Nini = len(a)
- Nfin = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
else:
- nbb = 0
+ n_hists = 0
if log:
log = {'err': []}
# we assume that no distances are null except those of the diagonal of
# distances
- if nbb:
- u = np.ones((Nini, nbb)) / Nini
- v = np.ones((Nfin, nbb)) / Nfin
+ if n_hists:
+ u = np.ones((dim_a, n_hists)) / dim_a
+ v = np.ones((dim_b, n_hists)) / dim_b
else:
- u = np.ones(Nini) / Nini
- v = np.ones(Nfin) / Nfin
+ u = np.ones(dim_a) / dim_a
+ v = np.ones(dim_b) / dim_b
# print(reg)
@@ -384,13 +369,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2)
else:
# compute right marginal tmp2= (diag(u)Kdiag(v))^T1
np.einsum('i,ij,j->j', u, K, v, out=tmp2)
- err = np.linalg.norm(tmp2 - b)**2 # violation of marginal
+ err = np.linalg.norm(tmp2 - b) # violation of marginal
if log:
log['err'].append(err)
@@ -404,7 +388,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log['u'] = u
log['v'] = v
- if nbb: # return only loss
+ if n_hists: # return only loss
res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
@@ -419,7 +403,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
-def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
+def _greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -443,7 +427,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
@@ -451,12 +435,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
+ b : ndarray, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, n_b)
loss matrix
reg : float
Regularization term >0
@@ -469,7 +453,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, n_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -481,7 +465,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
- >>> ot.bregman.greenkhorn(a, b, M, 1)
+ >>> ot.bregman._greenkhorn(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
@@ -509,16 +493,16 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
- n = a.shape[0]
- m = b.shape[0]
+ dim_a = a.shape[0]
+ dim_b = b.shape[0]
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty_like(M)
np.divide(M, -reg, out=K)
np.exp(K, out=K)
- u = np.full(n, 1. / n)
- v = np.full(m, 1. / m)
+ u = np.full(dim_a, 1. / dim_a)
+ v = np.full(dim_b, 1. / dim_b)
G = u[:, np.newaxis] * K * v[np.newaxis, :]
viol = G.sum(1) - a
@@ -571,8 +555,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
return G
-def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
- warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
+def _sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=20,
+ log=False, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization
@@ -588,7 +573,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
@@ -599,11 +584,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, n_b)
loss matrix
reg : float
Regularization term >0
@@ -622,7 +607,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, n_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -634,7 +619,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
- >>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
+ >>> ot.bregman._sinkhorn_stabilized(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
@@ -667,10 +652,10 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# test if multiple target
if len(b.shape) > 1:
- nbb = b.shape[1]
+ n_hists = b.shape[1]
a = a[:, np.newaxis]
else:
- nbb = 0
+ n_hists = 0
# init data
na = len(a)
@@ -687,8 +672,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
else:
alpha, beta = warmstart
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u, v = np.ones((na, n_hists)) / na, np.ones((nb, n_hists)) / nb
else:
u, v = np.ones(na) / na, np.ones(nb) / nb
@@ -720,13 +705,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
# remove numerical problems and store them in K
if np.abs(u).max() > tau or np.abs(v).max() > tau:
- if nbb:
+ if n_hists:
alpha, beta = alpha + reg * \
np.max(np.log(u), 1), beta + reg * np.max(np.log(v))
else:
alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v)
- if nbb:
- u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb
+ if n_hists:
+ u, v = np.ones((na, n_hists)) / na, np.ones((nb, n_hists)) / nb
else:
u, v = np.ones(na) / na, np.ones(nb) / nb
K = get_K(alpha, beta)
@@ -734,12 +719,15 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
if cpt % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- if nbb:
- err = np.sum((u - uprev)**2) / np.sum((u)**2) + \
- np.sum((v - vprev)**2) / np.sum((v)**2)
+ if n_hists:
+ err_u = abs(u - uprev).max()
+ err_u /= max(abs(u).max(), abs(uprev).max(), 1.)
+ err_v = abs(v - vprev).max()
+ err_v /= max(abs(v).max(), abs(vprev).max(), 1.)
+ err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
- err = np.linalg.norm((np.sum(transp, axis=0) - b))**2
+ err = np.linalg.norm((np.sum(transp, axis=0) - b))
if log:
log['err'].append(err)
@@ -766,7 +754,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
if log:
- if nbb:
+ if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
logu = alpha / reg + np.log(u)
@@ -776,26 +764,28 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
- if nbb:
- res = np.zeros((nbb))
- for i in range(nbb):
+ if n_hists:
+ res = np.zeros((n_hists))
+ for i in range(n_hists):
res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
return res
else:
return get_Gamma(alpha, beta, u, v)
-def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
- tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
+def _sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
+ numInnerItermax=100, tau=1e3, stopThr=1e-9,
+ warmstart=None, verbose=False, print_period=10,
+ log=False, **kwargs):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -812,7 +802,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
\gamma\geq 0
where :
- - M is the (ns,nt) metric cost matrix
+ - M is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)
@@ -823,18 +813,16 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (dim_b,)
samples in the target domain
- M : ndarray, shape (ns, nt)
+ M : ndarray, shape (dim_a, n_b)
loss matrix
reg : float
Regularization term >0
tau : float
thershold for max value in u or v for log scaling
- tau : float
- thershold for max value in u or v for log scaling
warmstart : tuple of vectors
if given then sarting values for alpha an beta log scalings
numItermax : int, optional
@@ -852,7 +840,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, n_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -864,7 +852,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
- >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
+ >>> ot.bregman._sinkhorn_epsilon_scaling(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
@@ -893,8 +881,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
# init data
- na = len(a)
- nb = len(b)
+ dim_a = len(a)
+ dim_b = len(b)
# nrelative umerical precision with 64 bits
numItermin = 35
@@ -907,14 +895,14 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
- alpha, beta = np.zeros(na), np.zeros(nb)
+ alpha, beta = np.zeros(dim_a), np.zeros(dim_b)
else:
alpha, beta = warmstart
def get_K(alpha, beta):
"""log space computation"""
- return np.exp(-(M - alpha.reshape((na, 1))
- - beta.reshape((1, nb))) / reg)
+ return np.exp(-(M - alpha.reshape((dim_a, 1))
+ - beta.reshape((1, dim_b))) / reg)
# print(np.min(K))
def get_reg(n): # exponential decreasing
@@ -927,7 +915,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
regi = get_reg(cpt)
- G, logi = sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
+ G, logi = _sinkhorn_stabilized(a, b, M, regi, numItermax=numInnerItermax, stopThr=1e-9, warmstart=(
alpha, beta), verbose=False, print_period=20, tau=tau, log=True)
alpha = logi['alpha']
@@ -986,8 +974,8 @@ def projC(gamma, q):
return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10))
-def barycenter(A, M, reg, weights=None, numItermax=1000,
- stopThr=1e-4, verbose=False, log=False):
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
+ stopThr=1e-4, verbose=False, log=False, **kwargs):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -1005,13 +993,15 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : ndarray, shape (d,n)
- n training distributions a_i of size d
- M : ndarray, shape (d,d)
- loss matrix for OT
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
reg : float
- Regularization term >0
- weights : ndarray, shape (n,)
+ Regularization term > 0
+ method : str (optional)
+ method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
+ weights : ndarray, shape (n_hists,)
Weights of each histogram a_i on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1025,7 +1015,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : (dim,) ndarray
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1036,8 +1026,70 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+ """
+
+ if method.lower() == 'sinkhorn':
+ return _barycenter(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose, log=log,
+ **kwargs)
+ elif method.lower() == 'sinkhorn_stabilized':
+ return _barycenter_stabilized(A, M, reg, numItermax=numItermax,
+ stopThr=stopThr, verbose=verbose,
+ log=log, **kwargs)
+ else:
+ raise ValueError("Unknown method '%s'." % method)
+
+
+def _barycenter(A, M, reg, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
"""
if weights is None:
@@ -1082,6 +1134,136 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)
+def _barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
+ stopThr=1e-4, verbose=False, log=False):
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
+ with stabilization.
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
+
+ where :
+
+ - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
+ - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
+ - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
+
+ The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
+
+ Parameters
+ ----------
+ A : ndarray, shape (dim, n_hists)
+ n_hists training distributions a_i of size dim
+ M : ndarray, shape (dim, dim)
+ loss matrix for OT
+ reg : float
+ Regularization term > 0
+ tau : float
+ thershold for max value in u or v for log scaling
+ weights : ndarray, shape (n_hists,)
+ Weights of each histogram a_i on the simplex (barycentric coodinates)
+ numItermax : int, optional
+ Max number of iterations
+ stopThr : float, optional
+ Stop threshol on error (>0)
+ verbose : bool, optional
+ Print information along iterations
+ log : bool, optional
+ record log if True
+
+
+ Returns
+ -------
+ a : (dim,) ndarray
+ Wasserstein barycenter
+ log : dict
+ log dictionary return only if log==True in parameters
+
+
+ References
+ ----------
+
+ .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
+
+ """
+
+ dim, n_hists = A.shape
+ if weights is None:
+ weights = np.ones(n_hists) / n_hists
+ else:
+ assert(len(weights) == A.shape[1])
+
+ if log:
+ log = {'err': []}
+
+ u = np.ones((dim, n_hists)) / dim
+ v = np.ones((dim, n_hists)) / dim
+
+ # print(reg)
+ # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
+ K = np.empty(M.shape, dtype=M.dtype)
+ np.divide(M, -reg, out=K)
+ np.exp(K, out=K)
+
+ cpt = 0
+ err = 1.
+ alpha = np.zeros(dim)
+ beta = np.zeros(dim)
+ q = np.ones(dim) / dim
+ while (err > stopThr and cpt < numItermax):
+ qprev = q
+ Kv = K.dot(v)
+ u = A / (Kv + 1e-16)
+ Ktu = K.T.dot(u)
+ q = geometricBar(weights, Ktu)
+ Q = q[:, None]
+ v = Q / (Ktu + 1e-16)
+ absorbing = False
+ if (u > tau).any() or (v > tau).any():
+ absorbing = True
+ print("YEAH absorbing")
+ alpha = alpha + reg * np.log(np.max(u, 1))
+ beta = beta + reg * np.log(np.max(v, 1))
+ K = np.exp((alpha[:, None] + beta[None, :] -
+ M) / reg)
+ v = np.ones_like(v)
+ Kv = K.dot(v)
+ if (np.any(Ktu == 0.)
+ or np.any(np.isnan(u)) or np.any(np.isnan(v))
+ or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ # we have reached the machine precision
+ # come back to previous solution and quit loop
+ warnings.warn('Numerical errors at iteration %s' % cpt)
+ q = qprev
+ break
+ if (cpt % 10 == 0 and not absorbing) or cpt == 0:
+ # we can speed up the process by checking for the error only all
+ # the 10th iterations
+ err = abs(u * Kv - A).max()
+ if log:
+ log['err'].append(err)
+ if verbose:
+ if cpt % 50 == 0:
+ print(
+ '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
+ print('{:5d}|{:8e}|'.format(cpt, err))
+
+ cpt += 1
+ if err > stopThr:
+ warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
+ "Try a larger entropy `reg`" +
+ "Or a larger absorption threshold `tau`.")
+ if log:
+ log['niter'] = cpt
+ log['logu'] = np.log(u + 1e-16)
+ log['logv'] = np.log(v + 1e-16)
+ return q, log
+ else:
+ return q
+
+
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
r"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
@@ -1101,16 +1283,16 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : ndarray, shape (n, w, h)
- n distributions (2D images) of size w x h
+ A : ndarray, shape (n_hists, width, height)
+ n distributions (2D images) of size width x height
reg : float
Regularization term >0
- weights : ndarray, shape (n,)
+ weights : ndarray, shape (n_hists,)
Weights of each image on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
- Stop threshol on error (>0)
+ Stop threshol on error (> 0)
stabThr : float, optional
Stabilization threshold to avoid numerical precision issue
verbose : bool, optional
@@ -1120,7 +1302,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Returns
-------
- a : ndarray, shape (w, h)
+ a : ndarray, shape (width, height)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1214,15 +1396,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : ndarray, shape (d)
+ a : ndarray, shape (n_observed)
observed distribution
- D : ndarray, shape (d, n)
+ D : ndarray, shape (dim, dim)
dictionary matrix
- M : ndarray, shape (d, d)
+ M : ndarray, shape (dim, dim)
loss matrix
- M0 : ndarray, shape (n, n)
+ M0 : ndarray, shape (n_observed, n_observed)
loss matrix
- h0 : ndarray, shape (n,)
+ h0 : ndarray, shape (dim,)
prior on h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1242,7 +1424,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : ndarray, shape (d,)
+ a : ndarray, shape (dim,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1315,22 +1497,22 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (ns, d)
+ X_s : ndarray, shape (dim_a, d)
samples in the source domain
- X_t : ndarray, shape (nt, d)
+ X_t : ndarray, shape (dim_b, d)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (ns,)
+ a : ndarray, shape (dim_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (dim_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1344,7 +1526,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (dim_a, n_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1352,11 +1534,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_a = 2
+ >>> n_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_a), (dim_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_b), (dim_b, 1))
>>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
array([[4.99977301e-01, 2.26989344e-05],
[2.26989344e-05, 4.99977301e-01]])
@@ -1405,22 +1587,22 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
\gamma\geq 0
where :
- - :math:`M` is the (ns,nt) metric cost matrix
+ - :math:`M` is the (dim_a, n_b) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (nt, d)
+ X_t : ndarray, shape (n_samples_b, d)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1434,7 +1616,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1442,11 +1624,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Examples
--------
- >>> n_s = 2
- >>> n_t = 2
+ >>> n_a = 2
+ >>> n_b = 2
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
array([4.53978687e-05])
@@ -1513,22 +1695,22 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
\gamma_b\geq 0
where :
- - :math:`M` (resp. :math:`M_a, M_b`) is the (ns,nt) metric cost matrix (resp (ns, ns) and (nt, nt))
+ - :math:`M` (resp. :math:`M_a, M_b`) is the (dim_a, n_b) metric cost matrix (resp (dim_a, ns) and (dim_b, nt))
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`a` and :math:`b` are source and target weights (sum to 1)
Parameters
----------
- X_s : ndarray, shape (ns, d)
+ X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
- X_t : ndarray, shape (nt, d)
+ X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
reg : float
Regularization term >0
- a : ndarray, shape (ns,)
+ a : ndarray, shape (n_samples_a,)
samples weights in the source domain
- b : ndarray, shape (nt,)
+ b : ndarray, shape (n_samples_b,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1541,18 +1723,18 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Returns
-------
- gamma : ndarray, shape (ns, nt)
+ gamma : ndarray, shape (n_samples_a, n_samples_b)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
Examples
--------
- >>> n_s = 2
- >>> n_t = 4
+ >>> n_a = 2
+ >>> n_b = 4
>>> reg = 0.1
- >>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
- >>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
+ >>> X_s = np.reshape(np.arange(n_samples_a), (n_samples_a, 1))
+ >>> X_t = np.reshape(np.arange(0, n_samples_b), (n_samples_b, 1))
>>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
array([1.499...])