summaryrefslogtreecommitdiff
path: root/test/test_sliced.py
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2022-03-24 10:53:47 +0100
committerGitHub <noreply@github.com>2022-03-24 10:53:47 +0100
commit767171593f2a98a26b9a39bf110a45085e3b982e (patch)
tree4eb4bcc657efc53a65c3fb4439bd0e0e106b6745 /test/test_sliced.py
parent9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff)
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft * Add matrix inverse and square root to backend * Eigen decomposition for older versions of pytorch (1.8.1 and older) * Corrected eigen decomposition for pytorch 1.8.1 and older * Spectral theorem is a thing * Optimization * small optimization * More functions converted * pep8 * remove a warning and prepare torch meshgrid for future torch release (which will change default indexing) * dots and pep8 * Meshgrid corrected for older version and prepared for future versions changes * New backend functions * Base transport * LinearTransport * All transport classes + pep8 * PR added to release file * Jcpot barycenter test * unbalanced with backend * pep8 * bug solve * test of domain adaptation with backends * solve bug for tic toc & macos * solving scipy deprecation warning * solving scipy deprecation warning attempt2 * solving scipy deprecation warning attempt3 * A warning is triggered when a float->int conversion is detected * bug solve * docs * release file updated * Better handling of float->int conversion in EMD * Corrected test for is_floating_point * docs * release file updated * cupy does not allow implicit cast * fromnumpy * added test * test da tf jax * test unbalanced with no provided histogram * using type_as argument in unif function correctly * pep8 * transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'test/test_sliced.py')
-rw-r--r--test/test_sliced.py32
1 files changed, 8 insertions, 24 deletions
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 91e0961..08ab4fb 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -123,9 +123,7 @@ def test_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
@@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -174,17 +170,13 @@ def test_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
@@ -203,9 +195,7 @@ def test_max_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
@@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")