summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/demo_OT_1D.py6
-rw-r--r--ot/__init__.py2
-rw-r--r--ot/datasets.py17
-rw-r--r--ot/plot.py10
-rw-r--r--ot/utils.py5
5 files changed, 32 insertions, 8 deletions
diff --git a/examples/demo_OT_1D.py b/examples/demo_OT_1D.py
index b17f902..accf722 100644
--- a/examples/demo_OT_1D.py
+++ b/examples/demo_OT_1D.py
@@ -7,7 +7,6 @@ Created on Fri Oct 21 09:51:45 2016
import numpy as np
import matplotlib.pylab as pl
-
import ot
@@ -30,10 +29,8 @@ M/=M.max()
#%% plot the distributions
pl.figure(1)
-
pl.plot(x,a,'b',label='Source distribution')
pl.plot(x,b,'r',label='Target distribution')
-
pl.legend()
#%% plot distributions and loss matrix
@@ -41,7 +38,6 @@ pl.legend()
pl.figure(2)
ot.plot.otplot1D(a,b,M,'Cost matrix M')
-
#%% EMD
G0=ot.emd(a,b,M)
@@ -50,8 +46,8 @@ pl.figure(3)
ot.plot.otplot1D(a,b,G0,'OT matrix G0')
#%% Sinkhorn
-lambd=1e-3
+lambd=1e-3
Gs=ot.sinkhorn(a,b,M,lambd)
pl.figure(4)
diff --git a/ot/__init__.py b/ot/__init__.py
index 24350a5..0a9e89b 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -12,6 +12,6 @@ from emd import emd
from bregman import sinkhorn
# utils functions
-from utils import dist,dots
+from utils import dist,dots,unif
__all__ = ["emd","sinkhorn","utils",'datasets','plot','dist','dots']
diff --git a/ot/datasets.py b/ot/datasets.py
index bb10ba4..3ebc2a1 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -1,10 +1,23 @@
import numpy as np
-
+import scipy as sp
def get_1D_gauss(n,m,s):
"return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
x=np.arange(n,dtype=np.float64)
h=np.exp(-(x-m)**2/(2*s^2))
- return h/h.sum() \ No newline at end of file
+ return h/h.sum()
+
+
+def get_2D_samples_gauss(n,m,sigma):
+ "return samples from 2D gaussian (n samples, mean m and cov sigma) "
+ if np.isscalar(sigma):
+ sigma=np.array([sigma,])
+ if len(sigma)>1:
+ P=sp.linalg.sqrtm(sigma)
+ res= np.random.randn(n,2).dot(P)+m
+ else:
+ res= np.random.randn(n,2)*np.sqrt(sigma)+m
+ return res
+ \ No newline at end of file
diff --git a/ot/plot.py b/ot/plot.py
index 743ab4a..f78daf6 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -38,3 +38,13 @@ def otplot1D(a,b,M,title=''):
pl.xlim((0,nb))
+def otplot2D_samples(xs,xt,G,thr=1e-8,**kwargs):
+ """ Plot matrix M in 2D with lines using alpha values"""
+ if ('color' not in kwargs) and ('c' not in kwargs):
+ kwargs['color']='k'
+ mx=G.max()
+ for i in range(xs.shape[0]):
+ for j in range(xt.shape[0]):
+ if G[i,j]/mx>thr:
+ pl.plot([xs[i,0],xt[j,0]],[xs[i,1],xt[j,1]],alpha=G[i,j]/mx,**kwargs)
+ \ No newline at end of file
diff --git a/ot/utils.py b/ot/utils.py
index 582c3ff..5feb4c6 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -3,6 +3,11 @@ import numpy as np
from scipy.spatial.distance import cdist
+def unif(n):
+ """ return a uniform histogram (simplex) """
+ return np.ones((n,))/n
+
+
def dist(x1,x2=None,metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2"""
if x2 is None: