diff options
author | RĂ©mi Flamary <remi.flamary@gmail.com> | 2023-03-09 09:58:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-09 09:58:48 +0100 |
commit | 263a36ff627257422d7e191f6882fb1c8fc68326 (patch) | |
tree | a78a4155ab74655ea24a81d822c2cddbdbe5b963 /ot | |
parent | a6d5d75c6ca584ab9736b528810b3595f2571d82 (diff) |
[MRG] Update pymanopt requirement and API for `ot.dr` (#443)
* updayte pymanopt API step 1
* add realease information
* update requireents for tests on windows
Diffstat (limited to 'ot')
-rw-r--r-- | ot/dr.py | 30 |
1 files changed, 16 insertions, 14 deletions
@@ -17,10 +17,10 @@ Dimension reduction with OT from scipy import linalg import autograd.numpy as np -from pymanopt.function import Autograd -from pymanopt.manifolds import Stiefel -from pymanopt import Problem -from pymanopt.solvers import SteepestDescent, TrustRegions + +import pymanopt +import pymanopt.manifolds +import pymanopt.optimizers def dist(x1, x2): @@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k): ui = np.ones((M.shape[0],)) vi = np.ones((M.shape[1],)) for i in range(k): - vi = w2 / (np.dot(K.T, ui)) - ui = w1 / (np.dot(K, vi)) + vi = w2 / (np.dot(K.T, ui) + 1e-50) + ui = w1 / (np.dot(K, vi) + 1e-50) G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1])) return G @@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter else: regmean = np.ones((len(xc), len(xc))) - @Autograd + manifold = pymanopt.manifolds.Stiefel(d, p) + + @pymanopt.function.autograd(manifold) def cost(P): # wda loss loss_b = 0 @@ -243,21 +245,21 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter return loss_w / loss_b # declare manifold and problem - manifold = Stiefel(d, p) - problem = Problem(manifold=manifold, cost=cost) + + problem = pymanopt.Problem(manifold=manifold, cost=cost) # declare solver and solve if solver is None: - solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose) + solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose) elif solver in ['tr', 'TrustRegions']: - solver = TrustRegions(maxiter=maxiter, logverbosity=verbose) + solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose) - Popt = solver.solve(problem, x=P0) + Popt = solver.run(problem, initial_point=P0) def proj(X): - return (X - mx.reshape((1, -1))).dot(Popt) + return (X - mx.reshape((1, -1))).dot(Popt.point) - return Popt, proj + return Popt.point, proj def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): |