summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorAlexandre Gramfort <alexandre.gramfort@m4x.org>2019-07-09 17:20:02 +0200
committerAlexandre Gramfort <alexandre.gramfort@m4x.org>2019-07-09 17:20:02 +0200
commit06fab4c1e5efbe79f91589917fba01c3fb300a87 (patch)
treeef1c832df0e3e4cb8aee044ae98751c9fac608aa /ot
parentb6fb14861accd20a323bfc5ef96c20883e4f6ce1 (diff)
more
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py112
-rw-r--r--ot/datasets.py32
-rw-r--r--ot/optim.py32
-rw-r--r--ot/plot.py6
4 files changed, 80 insertions, 102 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 13dfa3b..b67074f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
W : (nt) ndarray or float
@@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,) or np.ndarray (nt,nbb)
+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -823,11 +819,11 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
@@ -835,7 +831,7 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
thershold for max value in u or v for log scaling
tau : float
thershold for max value in u or v for log scaling
- warmstart : tible of vectors
+ warmstart : tuple of vectors
if given then sarting values for alpha an beta log scalings
numItermax : int, optional
Max number of iterations
@@ -850,10 +846,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1006,13 +1001,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
- A : np.ndarray (d,n)
+ A : ndarray, shape (d, n)
n training distributions a_i of size d
- M : np.ndarray (d,d)
+ M : ndarray, shape (d, d)
loss matrix for OT
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n,)
Weights of each histogram a_i on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1102,11 +1097,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
Parameters
----------
- A : np.ndarray (n,w,h)
+ A : ndarray, shape (n, w, h)
n distributions (2D images) of size w x h
reg : float
Regularization term >0
- weights : np.ndarray (n,)
+ weights : ndarray, shape (n,)
Weights of each image on the simplex (barycentric coodinates)
numItermax : int, optional
Max number of iterations
@@ -1119,15 +1114,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
log : bool, optional
record log if True
-
Returns
-------
- a : (w,h) ndarray
+ a : ndarray, shape (w, h)
2D Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
-
References
----------
@@ -1217,15 +1210,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Parameters
----------
- a : np.ndarray (d)
+ a : ndarray, shape (d,)
observed distribution
- D : np.ndarray (d,n)
+ D : ndarray, shape (d, n)
dictionary matrix
- M : np.ndarray (d,d)
+ M : ndarray, shape (d, d)
loss matrix
- M0 : np.ndarray (n,n)
+ M0 : ndarray, shape (n, n)
loss matrix
- h0 : np.ndarray (n,)
+ h0 : ndarray, shape (n,)
prior on h
reg : float
Regularization term >0 (Wasserstein data fitting)
@@ -1245,7 +1238,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
Returns
-------
- a : (d,) ndarray
+ a : ndarray, shape (d,)
Wasserstein barycenter
log : dict
log dictionary return only if log==True in parameters
@@ -1325,15 +1318,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1347,7 +1340,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1415,15 +1408,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1437,7 +1430,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -1523,15 +1516,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
Parameters
----------
- X_s : np.ndarray (ns, d)
+ X_s : ndarray, shape (ns, d)
samples in the source domain
- X_t : np.ndarray (nt, d)
+ X_t : ndarray, shape (nt, d)
samples in the target domain
reg : float
Regularization term >0
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples weights in the target domain
numItermax : int, optional
Max number of iterations
@@ -1542,17 +1535,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
log : bool, optional
record log if True
-
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Regularized optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
Examples
--------
-
>>> n_s = 2
>>> n_t = 4
>>> reg = 0.1
@@ -1564,7 +1555,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
References
----------
-
.. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
'''
if log:
diff --git a/ot/datasets.py b/ot/datasets.py
index e76e75d..ba0cfd9 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -17,7 +17,6 @@ def make_1D_gauss(n, m, s):
Parameters
----------
-
n : int
number of bins in the histogram
m : float
@@ -25,12 +24,10 @@ def make_1D_gauss(n, m, s):
s : float
standard deviaton of the gaussian distribution
-
Returns
-------
- h : np.array (n,)
- 1D histogram for a gaussian distribution
-
+ h : ndarray, shape (n,)
+ 1D histogram for a gaussian distribution
"""
x = np.arange(n, dtype=np.float64)
h = np.exp(-(x - m)**2 / (2 * s**2))
@@ -44,16 +41,15 @@ def get_1D_gauss(n, m, sigma):
def make_2D_samples_gauss(n, m, sigma, random_state=None):
- """return n samples drawn from 2D gaussian N(m,sigma)
+ """Return n samples drawn from 2D gaussian N(m,sigma)
Parameters
----------
-
n : int
number of samples to make
- m : np.array (2,)
+ m : ndarray, shape (2,)
mean value of the gaussian distribution
- sigma : np.array (2,2)
+ sigma : ndarray, shape (2, 2)
covariance matrix of the gaussian distribution
random_state : int, RandomState instance or None, optional (default=None)
If int, random_state is the seed used by the random number generator;
@@ -63,9 +59,8 @@ def make_2D_samples_gauss(n, m, sigma, random_state=None):
Returns
-------
- X : np.array (n,2)
- n samples drawn from N(m,sigma)
-
+ X : ndarray, shape (n, 2)
+ n samples drawn from N(m, sigma).
"""
generator = check_random_state(random_state)
@@ -86,11 +81,10 @@ def get_2D_samples_gauss(n, m, sigma, random_state=None):
def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
- """ dataset generation for classification problems
+ """Dataset generation for classification problems
Parameters
----------
-
dataset : str
type of classification problem (see code)
n : int
@@ -105,13 +99,11 @@ def make_data_classif(dataset, n, nz=.5, theta=0, random_state=None, **kwargs):
Returns
-------
- X : np.array (n,d)
- n observation of size d
- y : np.array (n,)
- labels of the samples
-
+ X : ndarray, shape (n, d)
+ n observation of size d
+ y : ndarray, shape (n,)
+ labels of the samples.
"""
-
generator = check_random_state(random_state)
if dataset.lower() == '3gauss':
diff --git a/ot/optim.py b/ot/optim.py
index f94aceb..65baf9d 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -26,14 +26,13 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
Parameters
----------
-
- f : function
+ f : callable
loss function
- xk : np.ndarray
+ xk : ndarray
initial position
- pk : np.ndarray
+ pk : ndarray
descent direction
- gfk : np.ndarray
+ gfk : ndarray
gradient of f at xk
old_fval : float
loss value at xk
@@ -161,15 +160,15 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
- G0 : np.ndarray (ns,nt), optional
+ G0 : ndarray, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -299,17 +298,17 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : ndarray, shape (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : ndarray, shape (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : ndarray, shape (ns, nt)
loss matrix
reg1 : float
Entropic Regularization term >0
reg2 : float
Second Regularization term >0
- G0 : np.ndarray (ns,nt), optional
+ G0 : ndarray, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
@@ -326,15 +325,13 @@ def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
Returns
-------
- gamma : (ns x nt) ndarray
+ gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
-
References
----------
-
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
@@ -422,13 +419,12 @@ def solve_1d_linesearch_quad(a, b, c):
Parameters
----------
a,b,c : float
- The coefficients of the quadratic function
+ The coefficients of the quadratic function
Returns
-------
x : float
The optimal value which leads to the minimal cost
-
"""
f0 = c
df0 = b
diff --git a/ot/plot.py b/ot/plot.py
index a409d4a..f403e98 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -26,11 +26,11 @@ def plot1D_mat(a, b, M, title=''):
Parameters
----------
- a : np.array, shape (na,)
+ a : ndarray, shape (na,)
Source distribution
- b : np.array, shape (nb,)
+ b : ndarray, shape (nb,)
Target distribution
- M : np.array, shape (na,nb)
+ M : ndarray, shape (na, nb)
Matrix to plot
"""
na, nb = M.shape