diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2022-03-24 10:53:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-24 10:53:47 +0100 |
commit | 767171593f2a98a26b9a39bf110a45085e3b982e (patch) | |
tree | 4eb4bcc657efc53a65c3fb4439bd0e0e106b6745 /test/test_ot.py | |
parent | 9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff) |
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft
* Add matrix inverse and square root to backend
* Eigen decomposition for older versions of pytorch (1.8.1 and older)
* Corrected eigen decomposition for pytorch 1.8.1 and older
* Spectral theorem is a thing
* Optimization
* small optimization
* More functions converted
* pep8
* remove a warning and prepare torch meshgrid for future torch release (which will change default indexing)
* dots and pep8
* Meshgrid corrected for older version and prepared for future versions changes
* New backend functions
* Base transport
* LinearTransport
* All transport classes + pep8
* PR added to release file
* Jcpot barycenter test
* unbalanced with backend
* pep8
* bug solve
* test of domain adaptation with backends
* solve bug for tic toc & macos
* solving scipy deprecation warning
* solving scipy deprecation warning attempt2
* solving scipy deprecation warning attempt3
* A warning is triggered when a float->int conversion is detected
* bug solve
* docs
* release file updated
* Better handling of float->int conversion in EMD
* Corrected test for is_floating_point
* docs
* release file updated
* cupy does not allow implicit cast
* fromnumpy
* added test
* test da tf jax
* test unbalanced with no provided histogram
* using type_as argument in unif function correctly
* pep8
* transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_ot.py')
-rw-r--r-- | test/test_ot.py | 19 |
1 files changed, 7 insertions, 12 deletions
diff --git a/test/test_ot.py b/test/test_ot.py index 3e2d845..bb258e2 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -47,8 +47,7 @@ def test_emd_backends(nx): G = ot.emd(a, a, M) - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) @@ -68,8 +67,7 @@ def test_emd2_backends(nx): val = ot.emd2(a, a, M) - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) valb = ot.emd2(ab, ab, Mb) @@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx): for tp in nx.__type_list__: print(nx.dtype_device(tp)) - ab = nx.from_numpy(a, type_as=tp) - Mb = nx.from_numpy(M, type_as=tp) + ab, Mb = nx.from_numpy(a, M, type_as=tp) Gb = ot.emd(ab, ab, Mb) @@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf(): # Check that everything stays on the CPU with tf.device("/CPU:0"): - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) w = ot.emd2(ab, ab, Mb) nx.assert_same_dtype_device(Mb, Gb) @@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf(): if len(tf.config.list_physical_devices('GPU')) > 0: # Check that everything happens on the GPU - ab = nx.from_numpy(a) - Mb = nx.from_numpy(M) + ab, Mb = nx.from_numpy(a, M) Gb = ot.emd(ab, ab, Mb) w = ot.emd2(ab, ab, Mb) nx.assert_same_dtype_device(Mb, Gb) @@ -310,8 +305,8 @@ def test_free_support_barycenter_backends(nx): X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init) - measures_locations2 = [nx.from_numpy(x) for x in measures_locations] - measures_weights2 = [nx.from_numpy(x) for x in measures_weights] + measures_locations2 = nx.from_numpy(*measures_locations) + measures_weights2 = nx.from_numpy(*measures_weights) X_init2 = nx.from_numpy(X_init) X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2) |