author    Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>  2021-12-03 12:37:05 +0100
committer GitHub <noreply@github.com>  2021-12-03 12:37:05 +0100
commit    ca69658400dc2ef6a7d3e531acffcd107400085f (patch)
tree      b77a28821067be5240cec2082fa1f119b1cfd1cd
parent    cb510644b2fd65e4ce216a7799ce7401f71548b8 (diff)
[MRG] Cupy backend (#315)
* Cupy backend
* pep8 + bug
* working even if cupy not installed
* attempt to force codecov to ignore cupy because no gpu can be used for testing on github
* docstring
* readme
-rw-r--r--  README.md              1
-rw-r--r--  ot/backend.py        302
-rw-r--r--  ot/gromov.py          35
-rw-r--r--  ot/optim.py            2
-rw-r--r--  test/test_backend.py  21
-rw-r--r--  test/test_bregman.py   1
-rw-r--r--  test/test_gromov.py   19
7 files changed, 355 insertions, 26 deletions
diff --git a/README.md b/README.md
index ff8056a..18064a3 100644
--- a/README.md
+++ b/README.md
@@ -196,6 +196,7 @@ The contributors to this library are
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
+* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
diff --git a/ot/backend.py b/ot/backend.py
index fa164c3..1630ac4 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -3,7 +3,7 @@
Multi-lib backend for POT
The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch,
-or Jax, POT code should work nonetheless.
+Jax, or Cupy, POT code should work nonetheless.
To achieve that, POT provides backend classes which implement functions in their respective backend,
imitating the Numpy API. As a convention, we use nx instead of np to refer to the backend.
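As a quick illustration of this convention, a function written once against nx runs unchanged on any supported array type. A minimal sketch (squared_norm is a hypothetical helper, not part of POT):

    import numpy as np
    from ot.backend import get_backend

    def squared_norm(x):
        # nx mirrors the numpy API for whichever backend produced x
        nx = get_backend(x)
        return nx.sum(nx.power(x, 2))

    squared_norm(np.arange(3.0))  # same call works for torch, jax, or cupy arrays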
@@ -44,6 +44,14 @@ except ImportError:
jax = False
jax_type = float
+try:
+ import cupy as cp
+ import cupyx
+ cp_type = cp.ndarray
+except ImportError:
+ cp = False
+ cp_type = float
+
str_type_error = "All array should be from the same type/backend. Current types are : {}"
@@ -57,6 +65,9 @@ def get_backend_list():
if jax:
lst.append(JaxBackend())
+ if cp:
+ lst.append(CupyBackend())
+
return lst
@@ -78,6 +89,8 @@ def get_backend(*args):
return TorchBackend()
elif isinstance(args[0], jax_type):
return JaxBackend()
+ elif isinstance(args[0], cp_type):
+ return CupyBackend()
else:
raise ValueError("Unknown type of non implemented backend.")
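The dispatch above can be exercised directly; a sketch assuming cupy is installed (without it, cp is False and the cupy branch is never reached):

    import numpy as np
    import cupy as cp
    from ot.backend import get_backend

    A = np.ones(3)
    assert get_backend(A).__name__ == 'numpy'
    assert get_backend(cp.asarray(A)).__name__ == 'cupy'
    get_backend(A, cp.asarray(A))  # raises ValueError: inputs mix two backends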
@@ -94,7 +107,8 @@ def to_numpy(*args):
class Backend():
"""
Backend abstract class.
- Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`
+ Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`,
+ :py:class:`CupyBackend`
- The `__name__` class attribute refers to the name of the backend.
- The `__type__` class attribute refers to the data structure used by the backend.
@@ -1500,3 +1514,287 @@ class TorchBackend(Backend):
assert a_dtype == b_dtype, "Dtype discrepancy"
assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
+
+
+class CupyBackend(Backend): # pragma: no cover
+ """
+ CuPy implementation of the backend
+
+ - `__name__` is "cupy"
+ - `__type__` is cp.ndarray
+ """
+
+ __name__ = 'cupy'
+ __type__ = cp_type
+ __type_list__ = None
+
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = cp.random.RandomState()
+
+ self.__type_list__ = [
+ cp.array(1, dtype=cp.float32),
+ cp.array(1, dtype=cp.float64)
+ ]
+
+ def to_numpy(self, a):
+ return cp.asnumpy(a)
+
+ def from_numpy(self, a, type_as=None):
+ if type_as is None:
+ return cp.asarray(a)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.asarray(a, dtype=type_as.dtype)
+
+ def set_gradients(self, val, inputs, grads):
+ # No gradients for cupy
+ return val
+
+ def zeros(self, shape, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.zeros(shape)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.zeros(shape, dtype=type_as.dtype)
+
+ def ones(self, shape, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.ones(shape)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.ones(shape, dtype=type_as.dtype)
+
+ def arange(self, stop, start=0, step=1, type_as=None):
+ return cp.arange(start, stop, step)
+
+ def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, (list, tuple)):
+ shape = tuple(int(i) for i in shape)
+ if type_as is None:
+ return cp.full(shape, fill_value)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.full(shape, fill_value, dtype=type_as.dtype)
+
+ def eye(self, N, M=None, type_as=None):
+ if type_as is None:
+ return cp.eye(N, M)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cp.eye(N, M, dtype=type_as.dtype)
+
+ def sum(self, a, axis=None, keepdims=False):
+ return cp.sum(a, axis, keepdims=keepdims)
+
+ def cumsum(self, a, axis=None):
+ return cp.cumsum(a, axis)
+
+ def max(self, a, axis=None, keepdims=False):
+ return cp.max(a, axis, keepdims=keepdims)
+
+ def min(self, a, axis=None, keepdims=False):
+ return cp.min(a, axis, keepdims=keepdims)
+
+ def maximum(self, a, b):
+ return cp.maximum(a, b)
+
+ def minimum(self, a, b):
+ return cp.minimum(a, b)
+
+ def abs(self, a):
+ return cp.abs(a)
+
+ def exp(self, a):
+ return cp.exp(a)
+
+ def log(self, a):
+ return cp.log(a)
+
+ def sqrt(self, a):
+ return cp.sqrt(a)
+
+ def power(self, a, exponents):
+ return cp.power(a, exponents)
+
+ def dot(self, a, b):
+ return cp.dot(a, b)
+
+ def norm(self, a):
+ return cp.sqrt(cp.sum(cp.square(a)))
+
+ def any(self, a):
+ return cp.any(a)
+
+ def isnan(self, a):
+ return cp.isnan(a)
+
+ def isinf(self, a):
+ return cp.isinf(a)
+
+ def einsum(self, subscripts, *operands):
+ return cp.einsum(subscripts, *operands)
+
+ def sort(self, a, axis=-1):
+ return cp.sort(a, axis)
+
+ def argsort(self, a, axis=-1):
+ return cp.argsort(a, axis)
+
+ def searchsorted(self, a, v, side='left'):
+ if a.ndim == 1:
+ return cp.searchsorted(a, v, side)
+ else:
+ # this is not a very efficient way to make cupy
+ # searchsorted work on 2d arrays
+ ret = cp.empty(v.shape, dtype=int)
+ for i in range(a.shape[0]):
+ ret[i, :] = cp.searchsorted(a[i, :], v[i, :], side)
+ return ret
+
+ def flip(self, a, axis=None):
+ return cp.flip(a, axis)
+
+ def outer(self, a, b):
+ return cp.outer(a, b)
+
+ def clip(self, a, a_min, a_max):
+ return cp.clip(a, a_min, a_max)
+
+ def repeat(self, a, repeats, axis=None):
+ return cp.repeat(a, repeats, axis)
+
+ def take_along_axis(self, arr, indices, axis):
+ return cp.take_along_axis(arr, indices, axis)
+
+ def concatenate(self, arrays, axis=0):
+ return cp.concatenate(arrays, axis)
+
+ def zero_pad(self, a, pad_width):
+ return cp.pad(a, pad_width)
+
+ def argmax(self, a, axis=None):
+ return cp.argmax(a, axis=axis)
+
+ def mean(self, a, axis=None):
+ return cp.mean(a, axis=axis)
+
+ def std(self, a, axis=None):
+ return cp.std(a, axis=axis)
+
+ def linspace(self, start, stop, num):
+ return cp.linspace(start, stop, num)
+
+ def meshgrid(self, a, b):
+ return cp.meshgrid(a, b)
+
+ def diag(self, a, k=0):
+ return cp.diag(a, k)
+
+ def unique(self, a):
+ return cp.unique(a)
+
+ def logsumexp(self, a, axis=None):
+ # Taken from
+ # https://github.com/scipy/scipy/blob/v1.7.1/scipy/special/_logsumexp.py#L7-L127
+ a_max = cp.amax(a, axis=axis, keepdims=True)
+
+ if a_max.ndim > 0:
+ a_max[~cp.isfinite(a_max)] = 0
+ elif not cp.isfinite(a_max):
+ a_max = 0
+
+ tmp = cp.exp(a - a_max)
+ s = cp.sum(tmp, axis=axis)
+ out = cp.log(s)
+ a_max = cp.squeeze(a_max, axis=axis)
+ out += a_max
+ return out
+
+ def stack(self, arrays, axis=0):
+ return cp.stack(arrays, axis)
+
+ def reshape(self, a, shape):
+ return cp.reshape(a, shape)
+
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.rand(*size)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return self.rng_.rand(*size, dtype=type_as.dtype)
+
+ def randn(self, *size, type_as=None):
+ if type_as is None:
+ return self.rng_.randn(*size)
+ else:
+ with cp.cuda.Device(type_as.device):
+ return self.rng_.randn(*size, dtype=type_as.dtype)
+
+ def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
+ data = self.from_numpy(data)
+ rows = self.from_numpy(rows)
+ cols = self.from_numpy(cols)
+ if type_as is None:
+ return cupyx.scipy.sparse.coo_matrix(
+ (data, (rows, cols)), shape=shape
+ )
+ else:
+ with cp.cuda.Device(type_as.device):
+ return cupyx.scipy.sparse.coo_matrix(
+ (data, (rows, cols)), shape=shape, dtype=type_as.dtype
+ )
+
+ def issparse(self, a):
+ return cupyx.scipy.sparse.issparse(a)
+
+ def tocsr(self, a):
+ if self.issparse(a):
+ return a.tocsr()
+ else:
+ return cupyx.scipy.sparse.csr_matrix(a)
+
+ def eliminate_zeros(self, a, threshold=0.):
+ if threshold > 0:
+ if self.issparse(a):
+ a.data[self.abs(a.data) <= threshold] = 0
+ else:
+ a[self.abs(a) <= threshold] = 0
+ if self.issparse(a):
+ a.eliminate_zeros()
+ return a
+
+ def todense(self, a):
+ if self.issparse(a):
+ return a.toarray()
+ else:
+ return a
+
+ def where(self, condition, x, y):
+ return cp.where(condition, x, y)
+
+ def copy(self, a):
+ return a.copy()
+
+ def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
+ return cp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
+
+ def dtype_device(self, a):
+ return a.dtype, a.device
+
+ def assert_same_dtype_device(self, a, b):
+ a_dtype, a_device = self.dtype_device(a)
+ b_dtype, b_device = self.dtype_device(b)
+
+ # cupy has implicit type conversion, so we consider the dtype
+ # check to pass automatically and only assert the device
+ assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}"
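Taken together, the new class lets backend-agnostic POT code run on GPU arrays without modification. A minimal sketch of the round-trip and type_as behaviour (assuming cupy and a CUDA device are available):

    import numpy as np
    import cupy as cp
    from ot.backend import get_backend

    a = cp.asarray(np.full(4, 0.25))  # uniform weights, stored on the GPU
    nx = get_backend(a)               # dispatches to CupyBackend
    z = nx.zeros((2, 2), type_as=a)   # allocated on a's device with a's dtype
    b = nx.from_numpy(nx.to_numpy(a), type_as=a)
    assert nx.allclose(a, b)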
diff --git a/ot/gromov.py b/ot/gromov.py
index ea667e4..2a70070 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -822,8 +822,12 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
- index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
- index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+ index_i = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
+ index_j = generator.choice(
+ len_p, size=nb_samples_p, p=nx.to_numpy(p), replace=False
+ )
for i in range(nb_samples_p):
if nx.issparse(T):
@@ -836,13 +840,13 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
index_k[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexi / nx.sum(T_indexi),
+ p=nx.to_numpy(T_indexi / nx.sum(T_indexi)),
replace=True
)
index_l[i] = generator.choice(
len_q,
size=nb_samples_q,
- p=T_indexj / nx.sum(T_indexj),
+ p=nx.to_numpy(T_indexj / nx.sum(T_indexj)),
replace=True
)
@@ -934,15 +938,17 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
index = np.zeros(2, dtype=int)
# Initialize with default marginal
- index[0] = generator.choice(len_p, size=1, p=p)
- index[1] = generator.choice(len_q, size=1, p=q)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
+ index[1] = generator.choice(len_q, size=1, p=nx.to_numpy(q))
T = nx.tocsr(emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False))
best_gw_dist_estimated = np.inf
for cpt in range(max_iter):
- index[0] = generator.choice(len_p, size=1, p=p)
+ index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
- index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+ index[1] = generator.choice(
+ len_q, size=1, p=nx.to_numpy(T_index0 / T_index0.sum())
+ )
if alpha == 1:
T = nx.tocsr(
@@ -1071,13 +1077,16 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
C_are_symmetric = nx.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and nx.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
for cpt in range(max_iter):
- index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+ index0 = generator.choice(
+ len_p, size=nb_samples_grad_p, p=nx.to_numpy(p), replace=False
+ )
Lik = 0
for i, index0_i in enumerate(index0):
- index1 = generator.choice(len_q,
- size=nb_samples_grad_q,
- p=T[index0_i, :] / nx.sum(T[index0_i, :]),
- replace=False)
+ index1 = generator.choice(
+ len_q, size=nb_samples_grad_q,
+ p=nx.to_numpy(T[index0_i, :] / nx.sum(T[index0_i, :])),
+ replace=False
+ )
# If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly.
if (not C_are_symmetric) and generator.rand(1) > 0.5:
Lik += nx.mean(loss_fun(
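The nx.to_numpy wrappers added throughout this file are needed because the sampling uses a NumPy generator, which cannot consume device arrays such as cupy's. A standalone illustration (assuming cupy is installed):

    import numpy as np
    import cupy as cp

    p = cp.asarray([0.25, 0.25, 0.5])
    rng = np.random.RandomState(42)
    # rng.choice(3, size=2, p=p) would fail: p is a cupy.ndarray, not numpy
    idx = rng.choice(3, size=2, p=cp.asnumpy(p))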
diff --git a/ot/optim.py b/ot/optim.py
index 9b8a8f7..f25e2c9 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -88,7 +88,7 @@ def line_search_armijo(
else:
if alpha_min is not None or alpha_max is not None:
alpha = np.clip(alpha, alpha_min, alpha_max)
- return alpha, fc[0], phi1
+ return float(alpha), fc[0], phi1
def solve_linesearch(
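The float(...) cast above guarantees that the step size is returned as a plain Python scalar rather than a NumPy (or 0-d backend) value; a quick illustration:

    import numpy as np

    alpha = np.clip(np.float64(1.5), 0.0, 1.0)
    type(alpha)         # numpy.float64, not a plain float
    type(float(alpha))  # float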
diff --git a/test/test_backend.py b/test/test_backend.py
index 1832b91..2e7eecc 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -7,7 +7,7 @@
import ot
import ot.backend
-from ot.backend import torch, jax
+from ot.backend import torch, jax, cp
import pytest
@@ -87,6 +87,20 @@ def test_get_backend():
with pytest.raises(ValueError):
get_backend(A, B2)
+ if cp:
+ A2 = cp.asarray(A)
+ B2 = cp.asarray(B)
+
+ nx = get_backend(A2)
+ assert nx.__name__ == 'cupy'
+
+ nx = get_backend(A2, B2)
+ assert nx.__name__ == 'cupy'
+
+ # test not unique types in input
+ with pytest.raises(ValueError):
+ get_backend(A, B2)
+
def test_convert_between_backends(nx):
@@ -240,7 +254,7 @@ def test_func_backends(nx):
# Sparse tensors test
sp_row = np.array([0, 3, 1, 0, 3])
sp_col = np.array([0, 3, 1, 2, 2])
- sp_data = np.array([4, 5, 7, 9, 0])
+ sp_data = np.array([4, 5, 7, 9, 0], dtype=np.float64)
lst_tot = []
@@ -393,7 +407,8 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argsort')
- A = nx.searchsorted(Mb, Mb, 'right')
+ tmp = nx.sort(Mb)
+ A = nx.searchsorted(tmp, tmp, 'right')
lst_b.append(nx.to_numpy(A))
lst_name.append('searchsorted')
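The test change above reflects that searchsorted is only well-defined on sorted input, so Mb is sorted before searching. For instance with NumPy:

    import numpy as np

    a = np.array([3.0, 1.0, 2.0])
    np.searchsorted(np.sort(a), 2.5, side='right')  # 2: insertion point in [1, 2, 3]
    # calling searchsorted on the unsorted a would return an unreliable index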
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 830052d..f42ac6f 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -888,6 +888,7 @@ def test_implemented_methods():
ot.bregman.sinkhorn2(a, b, M, epsilon, method=method)
+@pytest.skip_backend("cupy")
@pytest.skip_backend("jax")
@pytest.mark.filterwarnings("ignore:Bottleneck")
def test_screenkhorn(nx):
diff --git a/test/test_gromov.py b/test/test_gromov.py
index c4bc04c..5c181f2 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -54,9 +54,12 @@ def test_gromov(nx):
gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
gwb, logb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=True)
+ gwb = nx.to_numpy(gwb)
gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
- gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ gw_valb = nx.to_numpy(
+ ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
+ )
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -188,6 +191,7 @@ def test_entropic_gromov(nx):
C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True)
gwb, logb = ot.gromov.entropic_gromov_wasserstein2(
C1b, C2b, pb, qb, 'kl_loss', epsilon=1e-2, log=True)
+ gwb = nx.to_numpy(gwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])
@@ -287,8 +291,8 @@ def test_pointwise_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.0, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.0, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0, atol=1e-08)
G, log = ot.gromov.pointwise_gromov_wasserstein(
C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
@@ -298,8 +302,8 @@ def test_pointwise_gromov(nx):
Gb = nx.to_numpy(nx.todense(Gb))
np.testing.assert_allclose(G, Gb, atol=1e-06)
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.10342276348494964, atol=1e-8)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0015952535464736394, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.10342276348494964, atol=1e-8)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0015952535464736394, atol=1e-8)
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@@ -346,8 +350,8 @@ def test_sampled_gromov(nx):
np.testing.assert_allclose(
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
- np.testing.assert_allclose(logb['gw_dist_estimated'], 0.05679474884977278, atol=1e-08)
- np.testing.assert_allclose(logb['gw_dist_std'], 0.0005986592106971995, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_estimated']), 0.05679474884977278, atol=1e-08)
+ np.testing.assert_allclose(float(logb['gw_dist_std']), 0.0005986592106971995, atol=1e-08)
def test_gromov_barycenter(nx):
@@ -486,6 +490,7 @@ def test_fgw(nx):
fgw, log = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
fgwb, logb = ot.gromov.fused_gromov_wasserstein2(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True)
+ fgwb = nx.to_numpy(fgwb)
G = log['T']
Gb = nx.to_numpy(logb['T'])