author    Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>    2022-03-24 10:53:47 +0100
committer GitHub <noreply@github.com>    2022-03-24 10:53:47 +0100
commit    767171593f2a98a26b9a39bf110a45085e3b982e (patch)
tree      4eb4bcc657efc53a65c3fb4439bd0e0e106b6745 /test/test_optim.py
parent    9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff)
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft
* Add matrix inverse and square root to backend
* Eigen decomposition for older versions of pytorch (1.8.1 and older)
* Corrected eigen decomposition for pytorch 1.8.1 and older
* Spectral theorem is a thing
* Optimization
* small optimization
* More functions converted
* pep8
* remove a warning and prepare torch meshgrid for future torch release (which will change default indexing)
* dots and pep8
* Meshgrid corrected for older version and prepared for future versions changes
* New backend functions
* Base transport
* LinearTransport
* All transport classes + pep8
* PR added to release file
* Jcpot barycenter test
* unbalanced with backend
* pep8
* bug solve
* test of domain adaptation with backends
* solve bug for tic toc & macos
* solving scipy deprecation warning
* solving scipy deprecation warning attempt2
* solving scipy deprecation warning attempt3
* A warning is triggered when a float->int conversion is detected
* bug solve
* docs
* release file updated
* Better handling of float->int conversion in EMD
* Corrected test for is_floating_point
* docs
* release file updated
* cupy does not allow implicit cast
* fromnumpy
* added test
* test da tf jax
* test unbalanced with no provided histogram
* using type_as argument in unif function correctly
* pep8
* transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix

Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
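The test changes below rely on the vectorized `from_numpy` this commit adds: one backend call now converts several NumPy arrays at once and returns a tuple, replacing the previous one-call-per-array pattern with an explicit `type_as`. A minimal sketch of the pattern, assuming the NumPy backend from `ot.backend` purely for illustration:

    import numpy as np
    import ot
    from ot.backend import NumpyBackend

    nx = NumpyBackend()

    # Hypothetical inputs mirroring the shapes used in the tests below.
    a = ot.unif(5)            # uniform source histogram
    b = ot.unif(5)            # uniform target histogram
    M = np.random.rand(5, 5)  # cost matrix

    # One call converts all three arrays and returns a tuple.
    ab, bb, Mb = nx.from_numpy(a, b, M)

    # Per the "unif" item in the log above, type_as can still pin
    # dtype/device explicitly where needed, e.g. ot.unif(5, type_as=Mb).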
Diffstat (limited to 'test/test_optim.py')
-rw-r--r--  test/test_optim.py | 17
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/test/test_optim.py b/test/test_optim.py
index 41f9cbe..67e9d13 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -32,9 +32,7 @@ def test_conditional_gradient(nx):
     def fb(G):
         return 0.5 * nx.sum(G ** 2)

-    ab = nx.from_numpy(a)
-    bb = nx.from_numpy(b)
-    Mb = nx.from_numpy(M, type_as=ab)
+    ab, bb, Mb = nx.from_numpy(a, b, M)

     reg = 1e-1
@@ -74,9 +72,7 @@ def test_conditional_gradient_itermax(nx):
     def fb(G):
         return 0.5 * nx.sum(G ** 2)

-    ab = nx.from_numpy(a)
-    bb = nx.from_numpy(b)
-    Mb = nx.from_numpy(M, type_as=ab)
+    ab, bb, Mb = nx.from_numpy(a, b, M)

     reg = 1e-1
@@ -118,9 +114,7 @@ def test_generalized_conditional_gradient(nx):
     reg1 = 1e-3
     reg2 = 1e-1

-    ab = nx.from_numpy(a)
-    bb = nx.from_numpy(b)
-    Mb = nx.from_numpy(M, type_as=ab)
+    ab, bb, Mb = nx.from_numpy(a, b, M)

     G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
     Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
@@ -142,9 +136,12 @@ def test_line_search_armijo(nx):
     pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
     gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
     old_fval = -123
+
+    xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+
     # Should not throw an exception and return 0. for alpha
     alpha, a, b = ot.optim.line_search_armijo(
-        lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+        lambda x: 1, xkb, pkb, gfkb, old_fval
     )
     alpha_np, anp, bnp = ot.optim.line_search_armijo(
         lambda x: 1, xk, pk, gfk, old_fval
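For reference, a self-contained sketch of the refactored Armijo test pattern above, again assuming the NumPy backend. The `xk` value is a hypothetical placeholder (only `pk`, `gfk` and `old_fval` appear in the hunk); with a constant objective the call should not throw and should return 0. for alpha:

    import numpy as np
    import ot
    from ot.backend import NumpyBackend

    nx = NumpyBackend()

    xk = np.array([[0.25, 0.25], [0.25, 0.25]])  # hypothetical iterate
    pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
    gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
    old_fval = -123

    # Convert once up front, then reuse, as in the refactor above.
    xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)

    # Backend and NumPy variants take the same arguments.
    alpha, fc, fval = ot.optim.line_search_armijo(
        lambda x: 1, xkb, pkb, gfkb, old_fval
    )
    alpha_np, fc_np, fval_np = ot.optim.line_search_armijo(
        lambda x: 1, xk, pk, gfk, old_fval
    )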