diff options
Diffstat (limited to 'ot')
-rw-r--r-- | ot/__init__.py | 2 | ||||
-rw-r--r-- | ot/datasets.py | 17 | ||||
-rw-r--r-- | ot/plot.py | 10 | ||||
-rw-r--r-- | ot/utils.py | 5 |
4 files changed, 31 insertions, 3 deletions
diff --git a/ot/__init__.py b/ot/__init__.py index 24350a5..0a9e89b 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -12,6 +12,6 @@ from emd import emd from bregman import sinkhorn # utils functions -from utils import dist,dots +from utils import dist,dots,unif __all__ = ["emd","sinkhorn","utils",'datasets','plot','dist','dots'] diff --git a/ot/datasets.py b/ot/datasets.py index bb10ba4..3ebc2a1 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -1,10 +1,23 @@ import numpy as np - +import scipy as sp def get_1D_gauss(n,m,s): "return a 1D histogram for a gaussian distribution (n bins, mean m and std s) " x=np.arange(n,dtype=np.float64) h=np.exp(-(x-m)**2/(2*s^2)) - return h/h.sum()
\ No newline at end of file + return h/h.sum() + + +def get_2D_samples_gauss(n,m,sigma): + "return samples from 2D gaussian (n samples, mean m and cov sigma) " + if np.isscalar(sigma): + sigma=np.array([sigma,]) + if len(sigma)>1: + P=sp.linalg.sqrtm(sigma) + res= np.random.randn(n,2).dot(P)+m + else: + res= np.random.randn(n,2)*np.sqrt(sigma)+m + return res +
\ No newline at end of file @@ -38,3 +38,13 @@ def otplot1D(a,b,M,title=''): pl.xlim((0,nb)) +def otplot2D_samples(xs,xt,G,thr=1e-8,**kwargs): + """ Plot matrix M in 2D with lines using alpha values""" + if ('color' not in kwargs) and ('c' not in kwargs): + kwargs['color']='k' + mx=G.max() + for i in range(xs.shape[0]): + for j in range(xt.shape[0]): + if G[i,j]/mx>thr: + pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs) +
\ No newline at end of file diff --git a/ot/utils.py b/ot/utils.py index 582c3ff..5feb4c6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -3,6 +3,11 @@ import numpy as np from scipy.spatial.distance import cdist +def unif(n): + """ return a uniform histogram (simplex) """ + return np.ones((n,))/n + + def dist(x1,x2=None,metric='sqeuclidean'): """Compute distance between samples in x1 and x2""" if x2 is None: |