diff options
Diffstat (limited to 'ot/dr.py')
-rw-r--r-- | ot/dr.py | 8 |
1 files changed, 3 insertions, 5 deletions
@@ -9,15 +9,14 @@ from pymanopt import Problem from pymanopt.solvers import SteepestDescent, TrustRegions def dist(x1,x2): - """ Compute squared euclidean distance between samples + """ Compute squared euclidean distance between samples (autograd) """ x1p2=np.sum(np.square(x1),1) x2p2=np.sum(np.square(x2),1) return x1p2.reshape((-1,1))+x2p2.reshape((1,-1))-2*np.dot(x1,x2.T) def sinkhorn(w1,w2,M,reg,k): - """ - Simple solver for Sinkhorn algorithm with fixed number of iteration + """Sinkhorn algorithm with fixed number of iteration (autograd) """ K=np.exp(-M/reg) ui=np.ones((M.shape[0],)) @@ -29,8 +28,7 @@ def sinkhorn(w1,w2,M,reg,k): return G def split_classes(X,y): - """ - split samples in X by classes in y + """split samples in X by classes in y """ lstsclass=np.unique(y) return [X[y==i,:].astype(np.float32) for i in lstsclass] |