summaryrefslogtreecommitdiff
path: root/ot/dr.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/dr.py')
-rw-r--r--ot/dr.py32
1 files changed, 17 insertions, 15 deletions
diff --git a/ot/dr.py b/ot/dr.py
index 0955c55..b92cd14 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -17,10 +17,10 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
-from pymanopt.function import Autograd
-from pymanopt.manifolds import Stiefel
-from pymanopt import Problem
-from pymanopt.solvers import SteepestDescent, TrustRegions
+
+import pymanopt
+import pymanopt.manifolds
+import pymanopt.optimizers
def dist(x1, x2):
@@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
- vi = w2 / (np.dot(K.T, ui))
- ui = w1 / (np.dot(K, vi))
+ vi = w2 / (np.dot(K.T, ui) + 1e-50)
+ ui = w1 / (np.dot(K, vi) + 1e-50)
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G
@@ -167,7 +167,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
Size of dimensionnality reduction.
reg : float, optional
Regularization term >0 (entropic regularization)
- solver : None | str, optional
+ solver : None | str, optional
None for steepest descent or 'TrustRegions' for trust regions algorithm
else should be a pymanopt.solvers
sinkhorn_method : str
@@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
else:
regmean = np.ones((len(xc), len(xc)))
- @Autograd
+ manifold = pymanopt.manifolds.Stiefel(d, p)
+
+ @pymanopt.function.autograd(manifold)
def cost(P):
# wda loss
loss_b = 0
@@ -243,21 +245,21 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
return loss_w / loss_b
# declare manifold and problem
- manifold = Stiefel(d, p)
- problem = Problem(manifold=manifold, cost=cost)
+
+ problem = pymanopt.Problem(manifold=manifold, cost=cost)
# declare solver and solve
if solver is None:
- solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
- solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)
- Popt = solver.solve(problem, x=P0)
+ Popt = solver.run(problem, initial_point=P0)
def proj(X):
- return (X - mx.reshape((1, -1))).dot(Popt)
+ return (X - mx.reshape((1, -1))).dot(Popt.point)
- return Popt, proj
+ return Popt.point, proj
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):