summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-07-03 16:45:59 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-07-03 16:45:59 +0200
commit3ba85d6c8a1dac2722c5f51d4b2c49e3ed2ebd0f (patch)
treecdd567c2675d4eb3e410ef4549fb418ed8a7fe47 /ot/dr.py
parent8ff51eb898c3cb112d16a4a629639eddcef62516 (diff)
doc update
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 14a92c1..9187b57 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -9,15 +9,14 @@ from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1,x2):
- """ Compute squared euclidean distance between samples
+ """ Compute squared euclidean distance between samples (autograd)
"""
x1p2=np.sum(np.square(x1),1)
x2p2=np.sum(np.square(x2),1)
return x1p2.reshape((-1,1))+x2p2.reshape((1,-1))-2*np.dot(x1,x2.T)
def sinkhorn(w1,w2,M,reg,k):
- """
- Simple solver for Sinkhorn algorithm with fixed number of iteration
+ """Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K=np.exp(-M/reg)
ui=np.ones((M.shape[0],))
@@ -29,8 +28,7 @@ def sinkhorn(w1,w2,M,reg,k):
return G
def split_classes(X,y):
- """
- split samples in X by classes in y
+ """split samples in X by classes in y
"""
lstsclass=np.unique(y)
return [X[y==i,:].astype(np.float32) for i in lstsclass]