summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/ot/dr.py b/ot/dr.py
index c2f51f8..1671ca0 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -16,6 +16,7 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
+from pymanopt.function import Autograd
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
@@ -181,6 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no
else:
regmean = np.ones((len(xc), len(xc)))
+ @Autograd
def cost(P):
# wda loss
loss_b = 0