summaryrefslogtreecommitdiff
path: root/ot/datasets.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2016-10-24 15:36:16 +0200
committerRémi Flamary <remi.flamary@gmail.com>2016-10-24 15:36:16 +0200
commit838708a0134ab8e0e0ba7ee01bd97e7aa39b23bb (patch)
tree38131877d09373144b716cf03af5d0a2f8e240c0 /ot/datasets.py
parent04262dd46420304dd8d9bb01d37e4ceb3e44c11d (diff)
add plot and utils functions
Diffstat (limited to 'ot/datasets.py')
-rw-r--r--ot/datasets.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/ot/datasets.py b/ot/datasets.py
index bb10ba4..3ebc2a1 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -1,10 +1,23 @@
import numpy as np
-
+import scipy as sp
def get_1D_gauss(n,m,s):
"return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
x=np.arange(n,dtype=np.float64)
h=np.exp(-(x-m)**2/(2*s^2))
- return h/h.sum() \ No newline at end of file
+ return h/h.sum()
+
+
+def get_2D_samples_gauss(n,m,sigma):
+ "return samples from 2D gaussian (n samples, mean m and cov sigma) "
+ if np.isscalar(sigma):
+ sigma=np.array([sigma,])
+ if len(sigma)>1:
+ P=sp.linalg.sqrtm(sigma)
+ res= np.random.randn(n,2).dot(P)+m
+ else:
+ res= np.random.randn(n,2)*np.sqrt(sigma)+m
+ return res
+ \ No newline at end of file