From 767171593f2a98a26b9a39bf110a45085e3b982e Mon Sep 17 00:00:00 2001 From: Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> Date: Thu, 24 Mar 2022 10:53:47 +0100 Subject: [MRG] Domain adaptation and unbalanced solvers with backend support (#343) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: RĂ©mi Flamary --- test/test_weak.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'test/test_weak.py') diff --git a/test/test_weak.py b/test/test_weak.py index c4c3278..945efb1 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -45,9 +45,7 @@ def test_weak_ot_bakends(nx): G = ot.weak_optimal_transport(xs, xt, u, u) - xs2 = nx.from_numpy(xs) - xt2 = nx.from_numpy(xt) - u2 = nx.from_numpy(u) + xs2, xt2, u2 = nx.from_numpy(xs, xt, u) G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2) -- cgit v1.2.3