diff options
Diffstat (limited to 'ot/bregman.py')
-rw-r--r-- | ot/bregman.py | 112 |
1 files changed, 51 insertions, 61 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 13dfa3b..f39145d 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, Parameters ---------- - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) + b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + M : ndarray, shape (ns, nt) loss matrix reg : float Regularization term >0 @@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, Parameters ---------- - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) + b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + M : ndarray, shape (ns, nt) loss matrix reg : float Regularization term >0 @@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, log : bool, optional record log if True - Returns ------- W : (nt) ndarray or float @@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, Parameters ---------- - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) + b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + M : ndarray, shape (ns, nt) loss matrix reg : float Regularization term >0 @@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, log : bool, optional record log if True - Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= Parameters ---------- - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) + b : ndarray, shape (nt,) or ndarray, shape (nt, nbb) samples in the target domain, compute sinkhorn with multiple targets and fixed M if b is a matrix (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) + M : ndarray, shape (ns, nt) loss matrix reg : float Regularization term >0 @@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log= log : bool, optional record log if True - Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, Parameters ---------- - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : ndarray, shape (nt,) samples in the target domain - M : np.ndarray (ns,nt) + M : ndarray, shape (ns, nt) loss matrix reg : float Regularization term >0 @@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, log : bool, optional record log if True - Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -823,11 +819,11 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne Parameters ---------- - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : ndarray, shape (nt,) samples in the target domain - M : np.ndarray (ns,nt) + M : ndarray, shape (ns, nt) loss matrix reg : float Regularization term >0 @@ -835,7 +831,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne thershold for max value in u or v for log scaling tau : float thershold for max value in u or v for log scaling - warmstart : tible of vectors + warmstart : tuple of vectors if given then sarting values for alpha an beta log scalings numItermax : int, optional Max number of iterations @@ -850,10 +846,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne log : bool, optional record log if True - Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1006,13 +1001,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, Parameters ---------- - A : np.ndarray (d,n) + A : ndarray, shape (d,n) n training distributions a_i of size d - M : np.ndarray (d,d) + M : ndarray, shape (d,d) loss matrix for OT reg : float Regularization term >0 - weights : np.ndarray (n,) + weights : ndarray, shape (n,) Weights of each histogram a_i on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations @@ -1102,11 +1097,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 Parameters ---------- - A : np.ndarray (n,w,h) + A : ndarray, shape (n, w, h) n distributions (2D images) of size w x h reg : float Regularization term >0 - weights : np.ndarray (n,) + weights : ndarray, shape (n,) Weights of each image on the simplex (barycentric coodinates) numItermax : int, optional Max number of iterations @@ -1119,15 +1114,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1 log : bool, optional record log if True - Returns ------- - a : (w,h) ndarray + a : ndarray, shape (w, h) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters - References ---------- @@ -1217,15 +1210,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Parameters ---------- - a : np.ndarray (d) + a : ndarray, shape (d) observed distribution - D : np.ndarray (d,n) + D : ndarray, shape (d, n) dictionary matrix - M : np.ndarray (d,d) + M : ndarray, shape (d, d) loss matrix - M0 : np.ndarray (n,n) + M0 : ndarray, shape (n, n) loss matrix - h0 : np.ndarray (n,) + h0 : ndarray, shape (n,) prior on h reg : float Regularization term >0 (Wasserstein data fitting) @@ -1245,7 +1238,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, Returns ------- - a : (d,) ndarray + a : ndarray, shape (d,) Wasserstein barycenter log : dict log dictionary return only if log==True in parameters @@ -1325,15 +1318,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI Parameters ---------- - X_s : np.ndarray (ns, d) + X_s : ndarray, shape (ns, d) samples in the source domain - X_t : np.ndarray (nt, d) + X_t : ndarray, shape (nt, d) samples in the target domain reg : float Regularization term >0 - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : ndarray, shape (nt,) samples weights in the target domain numItermax : int, optional Max number of iterations @@ -1347,7 +1340,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1415,15 +1408,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Parameters ---------- - X_s : np.ndarray (ns, d) + X_s : ndarray, shape (ns, d) samples in the source domain - X_t : np.ndarray (nt, d) + X_t : ndarray, shape (nt, d) samples in the target domain reg : float Regularization term >0 - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : ndarray, shape (nt,) samples weights in the target domain numItermax : int, optional Max number of iterations @@ -1437,7 +1430,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters @@ -1523,15 +1516,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli Parameters ---------- - X_s : np.ndarray (ns, d) + X_s : ndarray, shape (ns, d) samples in the source domain - X_t : np.ndarray (nt, d) + X_t : ndarray, shape (nt, d) samples in the target domain reg : float Regularization term >0 - a : np.ndarray (ns,) + a : ndarray, shape (ns,) samples weights in the source domain - b : np.ndarray (nt,) + b : ndarray, shape (nt,) samples weights in the target domain numItermax : int, optional Max number of iterations @@ -1542,17 +1535,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli log : bool, optional record log if True - Returns ------- - gamma : (ns x nt) ndarray + gamma : ndarray, shape (ns, nt) Regularized optimal transportation matrix for the given parameters log : dict log dictionary return only if log==True in parameters Examples -------- - >>> n_s = 2 >>> n_t = 4 >>> reg = 0.1 @@ -1564,7 +1555,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli References ---------- - .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018 ''' if log: |