summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2023-03-09 09:58:48 +0100
committerGitHub <noreply@github.com>2023-03-09 09:58:48 +0100
commit263a36ff627257422d7e191f6882fb1c8fc68326 (patch)
treea78a4155ab74655ea24a81d822c2cddbdbe5b963
parenta6d5d75c6ca584ab9736b528810b3595f2571d82 (diff)
[MRG] Update pymanopt requirement and API for `ot.dr` (#443)
* updayte pymanopt API step 1 * add realease information * update requireents for tests on windows
-rw-r--r--.github/requirements_test_windows.txt3
-rw-r--r--RELEASES.md2
-rw-r--r--docs/requirements_rtd.txt3
-rw-r--r--ot/dr.py30
-rw-r--r--requirements.txt3
5 files changed, 21 insertions, 20 deletions
diff --git a/.github/requirements_test_windows.txt b/.github/requirements_test_windows.txt
index b94392f..448a7fc 100644
--- a/.github/requirements_test_windows.txt
+++ b/.github/requirements_test_windows.txt
@@ -3,8 +3,7 @@ scipy>=1.3
cython
matplotlib
autograd
-pymanopt==0.2.4; python_version <'3'
-pymanopt==0.2.6rc1; python_version >= '3'
+pymanopt
cvxopt
scikit-learn
pytest \ No newline at end of file
diff --git a/RELEASES.md b/RELEASES.md
index e251c30..bf2ce2e 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -15,6 +15,8 @@
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
+- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
+ Pymanopt (PR #443)
#### Closed issues
diff --git a/docs/requirements_rtd.txt b/docs/requirements_rtd.txt
index 11957fb..30082bb 100644
--- a/docs/requirements_rtd.txt
+++ b/docs/requirements_rtd.txt
@@ -9,7 +9,6 @@ scipy>=1.0
cython
matplotlib
autograd
-pymanopt==0.2.4; python_version <'3'
-pymanopt; python_version >= '3'
+pymanopt
cvxopt
scikit-learn \ No newline at end of file
diff --git a/ot/dr.py b/ot/dr.py
index 0955c55..1b97841 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -17,10 +17,10 @@ Dimension reduction with OT
from scipy import linalg
import autograd.numpy as np
-from pymanopt.function import Autograd
-from pymanopt.manifolds import Stiefel
-from pymanopt import Problem
-from pymanopt.solvers import SteepestDescent, TrustRegions
+
+import pymanopt
+import pymanopt.manifolds
+import pymanopt.optimizers
def dist(x1, x2):
@@ -38,8 +38,8 @@ def sinkhorn(w1, w2, M, reg, k):
ui = np.ones((M.shape[0],))
vi = np.ones((M.shape[1],))
for i in range(k):
- vi = w2 / (np.dot(K.T, ui))
- ui = w1 / (np.dot(K, vi))
+ vi = w2 / (np.dot(K.T, ui) + 1e-50)
+ ui = w1 / (np.dot(K, vi) + 1e-50)
G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1]))
return G
@@ -222,7 +222,9 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
else:
regmean = np.ones((len(xc), len(xc)))
- @Autograd
+ manifold = pymanopt.manifolds.Stiefel(d, p)
+
+ @pymanopt.function.autograd(manifold)
def cost(P):
# wda loss
loss_b = 0
@@ -243,21 +245,21 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter
return loss_w / loss_b
# declare manifold and problem
- manifold = Stiefel(d, p)
- problem = Problem(manifold=manifold, cost=cost)
+
+ problem = pymanopt.Problem(manifold=manifold, cost=cost)
# declare solver and solve
if solver is None:
- solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.SteepestDescent(max_iterations=maxiter, log_verbosity=verbose)
elif solver in ['tr', 'TrustRegions']:
- solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
+ solver = pymanopt.optimizers.TrustRegions(max_iterations=maxiter, log_verbosity=verbose)
- Popt = solver.solve(problem, x=P0)
+ Popt = solver.run(problem, initial_point=P0)
def proj(X):
- return (X - mx.reshape((1, -1))).dot(Popt)
+ return (X - mx.reshape((1, -1))).dot(Popt.point)
- return Popt, proj
+ return Popt.point, proj
def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
diff --git a/requirements.txt b/requirements.txt
index 7cbb29a..9be4deb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,7 @@ numpy>=1.20
scipy>=1.3
matplotlib
autograd
-pymanopt==0.2.4; python_version <'3'
-pymanopt==0.2.6rc1; python_version >= '3'
+pymanopt
cvxopt
scikit-learn
torch