summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py227
1 files changed, 109 insertions, 118 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 321712b..70e4208 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -16,7 +16,7 @@ from .utils import unif, dist
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
- u"""
+ r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -73,12 +73,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
--------
>>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
References
@@ -131,7 +131,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
- u"""
+ r"""
Solve the entropic regularization optimal transport problem and return the loss
The function solves the following optimization problem:
@@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
W : (nt) ndarray or float
@@ -188,11 +187,11 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
--------
>>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn2(a,b,M,1)
- array([ 0.26894142])
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn2(a, b, M, 1)
+ array([0.26894142])
@@ -241,14 +240,14 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
b = np.asarray(b, dtype=np.float64)
if len(b.shape) < 2:
- b = b.reshape((-1, 1))
+ b = b[:, None]
return sink()
def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
stopThr=1e-9, verbose=False, log=False, **kwargs):
- """
+ r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
@@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -302,12 +300,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
--------
>>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.sinkhorn(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.sinkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
References
@@ -422,7 +420,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
- """
+ r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The algorithm used is based on the paper
@@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -481,12 +478,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
--------
>>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.bregman.greenkhorn(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.bregman.greenkhorn(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
References
@@ -576,7 +573,7 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
- """
+ r"""
Solve the entropic regularization OT problem with log stabilization
The function solves the following optimization problem:
@@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -639,8 +635,8 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.bregman.sinkhorn_stabilized(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
References
@@ -769,10 +765,14 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
cpt = cpt + 1
- # print('err=',err,' cpt=',cpt)
if log:
- log['logu'] = alpha / reg + np.log(u)
- log['logv'] = beta / reg + np.log(v)
+ if nbb:
+ alpha = alpha[:, None]
+ beta = beta[:, None]
+ logu = alpha / reg + np.log(u)
+ logv = beta / reg + np.log(v)
+ log['logu'] = logu
+ log['logv'] = logv
log['alpha'] = alpha + reg * np.log(u)
log['beta'] = beta + reg * np.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
@@ -796,7 +796,7 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100,
tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs):
- """
+ r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
@@ -823,11 +823,11 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -835,7 +835,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
thershold for max value in u or v for log scaling
tau : float
thershold for max value in u or v for log scaling
- warmstart : tible of vectors
+ warmstart : tuple of vectors
if given then sarting values for alpha an beta log scalings
numItermax : int, optional
Max number of iterations
@@ -850,10 +850,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -862,12 +861,12 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
--------
>>> import ot
- >>> a=[.5,.5]
- >>> b=[.5,.5]
- >>> M=[[0.,1.],[1.,0.]]
- >>> ot.bregman.sinkhorn_epsilon_scaling(a,b,M,1)
- array([[ 0.36552929, 0.13447071],
- [ 0.13447071, 0.36552929]])
+ >>> a=[.5, .5]
+ >>> b=[.5, .5]
+ >>> M=[[0., 1.], [1., 0.]]
+ >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
+ array([[0.36552929, 0.13447071],
+ [0.13447071, 0.36552929]])
References
@@ -989,7 +988,7 @@ def projC(gamma, q):
def barycenter(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False):
- """Compute the entropic regularized wasserstein barycenter of distributions A
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
The function solves the following optimization problem:
@@ -1006,13 +1005,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : np.ndarray (d,n)
+ A : ndarray, shape (d,n)
n training distributions a_i of size d
- M : np.ndarray (d,d)
+ M : ndarray, shape (d,d)
loss matrix for OT
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n,)
Weights of each histogram a_i on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1084,7 +1083,7 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False):
- """Compute the entropic regularized wasserstein barycenter of distributions A
+ r"""Compute the entropic regularized wasserstein barycenter of distributions A
where A is a collection of 2D images.
The function solves the following optimization problem:
@@ -1102,11 +1101,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : np.ndarray (n,w,h)
+ A : ndarray, shape (n, w, h)
n distributions (2D images) of size w x h
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n,)
Weights of each image on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1119,15 +1118,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
log : bool, optional
record log if True
-
Returns
-------
- a : (w,h) ndarray
+ a : ndarray, shape (w, h)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
-
References
----------
@@ -1195,7 +1192,7 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
stopThr=1e-3, verbose=False, log=False):
- """
+ r"""
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
The function solve the following optimization problem:
@@ -1217,15 +1214,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (d)
+ a : ndarray, shape (d)
observed distribution
- D : np.ndarray (d,n)
+ D : ndarray, shape (d, n)
dictionary matrix
- M : np.ndarray (d,d)
+ M : ndarray, shape (d, d)
loss matrix
- M0 : np.ndarray (n,n)
+ M0 : ndarray, shape (n, n)
loss matrix
- h0 : np.ndarray (n,)
+ h0 : ndarray, shape (n,)
prior on h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1245,7 +1242,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : ndarray, shape (d,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1302,7 +1299,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
- '''
+ r'''
Solve the entropic regularization optimal transport problem and return the
OT matrix from empirical data
@@ -1325,15 +1322,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1347,7 +1344,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1360,10 +1357,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
>>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
- >>> emp_sinkhorn = empirical_sinkhorn(X_s, X_t, reg, verbose=False)
- >>> print(emp_sinkhorn)
- >>> [[4.99977301e-01 2.26989344e-05]
- [2.26989344e-05 4.99977301e-01]]
+ >>> empirical_sinkhorn(X_s, X_t, reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE
+ array([[4.99977301e-01, 2.26989344e-05],
+ [2.26989344e-05, 4.99977301e-01]])
References
@@ -1392,7 +1388,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
- '''
+ r'''
Solve the entropic regularization optimal transport problem from empirical
data and return the OT loss
@@ -1416,15 +1412,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1438,7 +1434,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1451,9 +1447,8 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
>>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
- >>> loss_sinkhorn = empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
- >>> print(loss_sinkhorn)
- >>> [4.53978687e-05]
+ >>> empirical_sinkhorn2(X_s, X_t, reg, verbose=False)
+ array([4.53978687e-05])
References
@@ -1482,7 +1477,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, verbose=False, log=False, **kwargs):
- '''
+ r'''
Compute the sinkhorn divergence loss from empirical data
The function solves the following optimization problems and return the
@@ -1525,15 +1520,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1544,30 +1539,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
Examples
--------
-
>>> n_s = 2
>>> n_t = 4
>>> reg = 0.1
>>> X_s = np.reshape(np.arange(n_s), (n_s, 1))
>>> X_t = np.reshape(np.arange(0, n_t), (n_t, 1))
- >>> emp_sinkhorn_div = empirical_sinkhorn_divergence(X_s, X_t, reg)
- >>> print(emp_sinkhorn_div)
- >>> [2.99977435]
+ >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS
+ array([1.499...])
References
----------
-
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log: