summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 14a92c1..9187b57 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -9,15 +9,14 @@ from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1,x2):
- """ Compute squared euclidean distance between samples
+ """ Compute squared euclidean distance between samples (autograd)
"""
x1p2=np.sum(np.square(x1),1)
x2p2=np.sum(np.square(x2),1)
return x1p2.reshape((-1,1))+x2p2.reshape((1,-1))-2*np.dot(x1,x2.T)
def sinkhorn(w1,w2,M,reg,k):
- """
- Simple solver for Sinkhorn algorithm with fixed number of iteration
+ """Sinkhorn algorithm with fixed number of iteration (autograd)
"""
K=np.exp(-M/reg)
ui=np.ones((M.shape[0],))
@@ -29,8 +28,7 @@ def sinkhorn(w1,w2,M,reg,k):
return G
def split_classes(X,y):
- """
- split samples in X by classes in y
+ """split samples in X by classes in y
"""
lstsclass=np.unique(y)
return [X[y==i,:].astype(np.float32) for i in lstsclass]