From 263a36ff627257422d7e191f6882fb1c8fc68326 Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 9 Mar 2023 09:58:48 +0100 Subject: [MRG] Update pymanopt requirement and API for `ot.dr` (#443) * updayte pymanopt API step 1 * add realease information * update requireents for tests on windows --- ot/dr.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) (limited to 'ot') diff --git a/ot/dr.py b/ot/dr.py index 0955c55..1b97841 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -17,10 +17,10 @@ Dimension reduction with OT from scipy import linalg import autograd.numpy as np -from pymanopt.function import Autograd -from pymanopt.manifolds import Stiefel -from pymanopt import Problem -from pymanopt.solvers import SteepestDescent, TrustRegions + +import pymanopt +import pymanopt.manifolds +import pymanopt.optimizers def dist(x1, x2): @@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k): ui = np.ones((M.shape[0],)) vi = np.ones((M.shape[1],)) for i in range(k): - vi = w2 / (np.dot(K.T, ui)) - ui = w1 / (np.dot(K, vi)) + vi = w2 / (np.dot(K.T, ui) + 1e-50) + ui = w1 / (np.dot(K, vi) + 1e-50) G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1])) return G @@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter else: regmean = np.ones((len(xc), len(xc))) - @Autograd + manifold = pymanopt.manifolds.Stiefel(d, p) + + @pymanopt.function.autograd(manifold) def cost(P): # wda loss loss_b = 0 @@ -243,21 +245,21 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter return loss_w / loss_b # declare manifold and problem - manifold = Stiefel(d, p) - problem = Problem(manifold=manifold, cost=cost) + + problem = pymanopt.Problem(manifold=manifold, cost=cost) # declare solver and solve if solver is None: - solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose) + solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose) elif solver in ['tr', 'TrustRegions']: - solver = TrustRegions(maxiter=maxiter, logverbosity=verbose) + solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose) - Popt = solver.solve(problem, x=P0) + Popt = solver.run(problem, initial_point=P0) def proj(X): - return (X - mx.reshape((1, -1))).dot(Popt) + return (X - mx.reshape((1, -1))).dot(Popt.point) - return Popt, proj + return Popt.point, proj def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): -- cgit v1.2.3