summaryrefslogtreecommitdiff
path: root/ot/bregman.py
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2019-07-09 17:20:02 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2019-07-09 17:20:02 +0200
commit06fab4c1e5efbe79f91589917fba01c3fb300a87 (patch)
treeef1c832df0e3e4cb8aee044ae98751c9fac608aa /ot/bregman.py
parentb6fb14861accd20a323bfc5ef96c20883e4f6ce1 (diff)
more
Diffstat (limited to 'ot/bregman.py')
-rw-r--r--ot/bregman.py112
1 files changed, 51 insertions, 61 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 13dfa3b..b67074f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
W : (nt) ndarray or float
@@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -823,11 +819,11 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -835,7 +831,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
thershold for max value in u or v for log scaling
tau : float
thershold for max value in u or v for log scaling
- warmstart : tible of vectors
+ warmstart : tuple of vectors
if given then sarting values for alpha an beta log scalings
numItermax : int, optional
Max number of iterations
@@ -850,10 +846,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1006,13 +1001,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : np.ndarray (d,n)
+ A : ndarray, shape (d,n)
n training distributions a_i of size d
- M : np.ndarray (d,d)
+ M : ndarray, shape (d,d)
loss matrix for OT
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n,)
Weights of each histogram a_i on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1102,11 +1097,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : np.ndarray (n,w,h)
+ A : ndarray, shape (n, w, h)
n distributions (2D images) of size w x h
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n,)
Weights of each image on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1119,15 +1114,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
log : bool, optional
record log if True
-
Returns
-------
- a : (w,h) ndarray
+ a : ndarray, shape (w, h)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
-
References
----------
@@ -1217,15 +1210,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (d)
+ a : ndarray, shape (d)
observed distribution
- D : np.ndarray (d,n)
+ D : ndarray, shape (d, n)
dictionary matrix
- M : np.ndarray (d,d)
+ M : ndarray, shape (d, d)
loss matrix
- M0 : np.ndarray (n,n)
+ M0 : ndarray, shape (n, n)
loss matrix
- h0 : np.ndarray (n,)
+ h0 : ndarray, shape (n,)
prior on h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1245,7 +1238,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : ndarray, shape (d,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1325,15 +1318,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1347,7 +1340,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1415,15 +1408,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1437,7 +1430,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1523,15 +1516,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1542,17 +1535,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
Examples
--------
-
>>> n_s = 2
>>> n_t = 4
>>> reg = 0.1
@@ -1564,7 +1555,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
References
----------
-
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log: