diff options
-rw-r--r-- | ot/bregman.py | 14 | ||||
-rw-r--r-- | ot/datasets.py | 71 | ||||
-rw-r--r-- | ot/utils.py | 23 |
3 files changed, 83 insertions, 25 deletions
diff --git a/ot/bregman.py b/ot/bregman.py index 9183bea..ad9a67a 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -276,14 +276,6 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal The optimization problem is solved using the algorithm described in [4] - - distrib : distribution to unmix - D : Dictionary - M : Metric matrix in the space of the distributions to unmix - M0 : Metric matrix in the space of the 'abundance values' to solve for - h0 : prior on solution (generally uniform distribution) - reg,reg0 : transport regularizations - alpha : how much should we trust the prior ? ([0,1]) Parameters ---------- @@ -300,7 +292,9 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal reg: float Regularization term >0 (Wasserstein data fitting) reg0: float - Regularization term >0 (Wasserstein reg with h0) + Regularization term >0 (Wasserstein reg with h0) + alpha: float + How much should we trust the prior ([0,1]) numItermax: int, optional Max number of iterations stopThr: float, optional @@ -318,7 +312,7 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal log: dict log dictionary return only if log==True in parameters - References + References ---------- .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016. 
diff --git a/ot/datasets.py b/ot/datasets.py index f22e345..6388d94 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -8,14 +8,50 @@ import scipy as sp def get_1D_gauss(n,m,s): - "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) " + """return a 1D histogram for a gaussian distribution (n bins, mean m and std s) + + Parameters + ---------- + + n : int + number of bins in the histogram + m : float + mean value of the gaussian distribution + s : float + standard deviation of the gaussian distribution + + + Returns + ------- + h : np.array (n,) + 1D histogram for a gaussian distribution + + """ x=np.arange(n,dtype=np.float64) h=np.exp(-(x-m)**2/(2*s^2)) return h/h.sum() def get_2D_samples_gauss(n,m,sigma): - "return samples from 2D gaussian (n samples, mean m and cov sigma) " + """return n samples drawn from 2D gaussian N(m,sigma) + + Parameters + ---------- + + n : int + number of samples to draw + m : np.array (2,) + mean value of the gaussian distribution + sigma : np.array (2,2) + covariance matrix of the gaussian distribution + + + Returns + ------- + X : np.array (n,2) + n samples drawn from N(m,sigma) + + """ if np.isscalar(sigma): sigma=np.array([sigma,]) if len(sigma)>1: @@ -26,8 +62,26 @@ def get_2D_samples_gauss(n,m,sigma): return res def get_data_classif(dataset,n,nz=.5,**kwargs): - """ - dataset generation + """ dataset generation for classification problems + + Parameters + ---------- + + dataset : str + type of classification problem (see code) + n : int + number of training samples + nz : float + noise level (>0) + + + Returns + ------- + X : np.array (n,d) + n observations of size d + y : np.array (n,) + labels of the samples + """ if dataset.lower()=='3gauss': y=np.floor((np.arange(n)*1.0/n*3))+1 @@ -50,15 +104,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs): x[y==3,0]=2. 
; x[y==3,1]=0 x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2) - x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) -# elif dataset.lower()=='sinreg': -# -# x=np.random.rand(n,1) -# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1) - + x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2) else: x=0 y=0 print("unknown dataset") - return x,y
\ No newline at end of file + return x,y.astype(int)
\ No newline at end of file diff --git a/ot/utils.py b/ot/utils.py index e5ec864..2110c01 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """ Various functions that can be useful """ @@ -6,7 +7,21 @@ from scipy.spatial.distance import cdist def unif(n): - """ return a uniform histogram of length n (simplex) """ + """ return a uniform histogram of length n (simplex) + + Parameters + ----------
+
+ n : int + number of bins in the histogram + + Returns + ------- + h : np.array (n,) + histogram of length n such that h_i=1/n for all i + + + """ return np.ones((n,))/n @@ -22,9 +37,9 @@ def dist(x1,x2=None,metric='sqeuclidean'): matrix with n2 samples of size d (if None then x2=x1) metric : str, fun, optional name of the metric to be computed (full list in the doc of scipy), If a string, - the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, + the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, - ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, + ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’. @@ -68,5 +83,5 @@ def dist0(n,method='lin_square'): def dots(*args): - """ Stupid but nice dots function for multiple matrix multiply """ + """ dots function for multiple matrix multiply """ return reduce(np.dot,args)
\ No newline at end of file |