summaryrefslogtreecommitdiff
path: root/ot
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-28 12:50:54 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-28 12:50:54 +0200
commit9523b1e95ed7c117554ff673532d150841092137 (patch)
tree1ed5b5060b3b59e79ca2e286058162cdd51c3b75 /ot
parentf33087d2b1790dac773782bb0d91bcfe7ce6a079 (diff)
doc datasets.py
Diffstat (limited to 'ot')
-rw-r--r--ot/bregman.py14
-rw-r--r--ot/datasets.py71
-rw-r--r--ot/utils.py23
3 files changed, 83 insertions, 25 deletions
diff --git a/ot/bregman.py b/ot/bregman.py
index 9183bea..ad9a67a 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -276,14 +276,6 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
The optimization problem is solved using the algorithm described in [4]
-
- distrib : distribution to unmix
- D : Dictionnary
- M : Metric matrix in the space of the distributions to unmix
- M0 : Metric matrix in the space of the 'abundance values' to solve for
- h0 : prior on solution (generally uniform distribution)
- reg,reg0 : transport regularizations
- alpha : how much should we trust the prior ? ([0,1])
Parameters
----------
@@ -300,7 +292,9 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
reg: float
Regularization term >0 (Wasserstein data fitting)
reg0: float
- Regularization term >0 (Wasserstein reg with h0)
+ Regularization term >0 (Wasserstein reg with h0)
+ alpha: float
+ How much should we trust the prior ([0,1])
numItermax: int, optional
Max number of iterations
stopThr: float, optional
@@ -318,7 +312,7 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
log: dict
log dictionary return only if log==True in parameters
- References
+ References
----------
.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
diff --git a/ot/datasets.py b/ot/datasets.py
index f22e345..6388d94 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -8,14 +8,50 @@ import scipy as sp
def get_1D_gauss(n,m,s):
- "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
+ """return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
+
+ Parameters
+ ----------
+
+ n : int
+ number of bins in the histogram
+ m : float
+ mean value of the gaussian distribution
+ s : float
+ standard deviation of the gaussian distribution
+
+
+ Returns
+ -------
+ h : np.array (n,)
+ 1D histogram for a gaussian distribution
+
+ """
x=np.arange(n,dtype=np.float64)
h=np.exp(-(x-m)**2/(2*s**2))
return h/h.sum()
def get_2D_samples_gauss(n,m,sigma):
- "return samples from 2D gaussian (n samples, mean m and cov sigma) "
+ """return n samples drawn from 2D gaussian N(m,sigma)
+
+ Parameters
+ ----------
+
+ n : int
+ number of samples to draw
+ m : np.array (2,)
+ mean value of the gaussian distribution
+ sigma : np.array (2,2)
+ covariance matrix of the gaussian distribution
+
+
+ Returns
+ -------
+ X : np.array (n,2)
+ n samples drawn from N(m,sigma)
+
+ """
if np.isscalar(sigma):
sigma=np.array([sigma,])
if len(sigma)>1:
@@ -26,8 +62,26 @@ def get_2D_samples_gauss(n,m,sigma):
return res
def get_data_classif(dataset,n,nz=.5,**kwargs):
- """
- dataset generation
+ """ dataset generation for classification problems
+
+ Parameters
+ ----------
+
+ dataset : str
+ type of classification problem (see code)
+ n : int
+ number of training samples
+ nz : float
+ noise level (>0)
+
+
+ Returns
+ -------
+ X : np.array (n,d)
+ n observations of size d
+ y : np.array (n,)
+ labels of the samples
+
"""
if dataset.lower()=='3gauss':
y=np.floor((np.arange(n)*1.0/n*3))+1
@@ -50,15 +104,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
x[y==3,0]=2. ; x[y==3,1]=0
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
- x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
-# elif dataset.lower()=='sinreg':
-#
-# x=np.random.rand(n,1)
-# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1)
-
+ x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
else:
x=0
y=0
print("unknown dataset")
- return x,y \ No newline at end of file
+ return x,y.astype(int) \ No newline at end of file
diff --git a/ot/utils.py b/ot/utils.py
index e5ec864..2110c01 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
"""
Various functions that can be useful
"""
@@ -6,7 +7,21 @@ from scipy.spatial.distance import cdist
def unif(n):
- """ return a uniform histogram of length n (simplex) """
+ """ return a uniform histogram of length n (simplex)
+
+ Parameters
+ ----------
+
+ n : int
+ number of bins in the histogram
+
+ Returns
+ -------
+ h : np.array (n,)
+ histogram of length n such that h_i=1/n for all i
+
+
+ """
return np.ones((n,))/n
@@ -22,9 +37,9 @@ def dist(x1,x2=None,metric='sqeuclidean'):
matrix with n2 samples of size d (if None then x2=x1)
metric : str, fun, optional
name of the metric to be computed (full list in the doc of scipy), If a string,
- the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
+ the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
- ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
+ ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
@@ -68,5 +83,5 @@ def dist0(n,method='lin_square'):
def dots(*args):
- """ Stupid but nice dots function for multiple matrix multiply """
+ """ dots function for multiple matrix multiply """
return reduce(np.dot,args) \ No newline at end of file