summaryrefslogtreecommitdiff
path: root/ot/backend.py
diff options
context:
space:
mode:
authorNicolas Courty <ncourty@irisa.fr>2021-11-02 14:19:57 +0100
committerGitHub <noreply@github.com>2021-11-02 14:19:57 +0100
commit6775a527f9d3c801f8cdd805d8f205b6a75551b9 (patch)
treec0ed5a7c297b4003688fec52d46f918ea0086a7d /ot/backend.py
parenta335324d008e8982be61d7ace937815a2bfa98f9 (diff)
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends * stat sets on functions * proper import * install recent torch on windows * install recent torch on windows * now testing all functions in backedn * add jax backedn * clenaup windowds * proper convert for jax backedn * pep8 * try again windows tests * test jax conversion * try proper widows tests * emd fuction ses backedn * better test partial OT * proper tests to_numpy and teplate Backend * pep8 * pep8 x2 * feaking sinkhorn works with torch * sinkhorn2 compatible * working ot.emd2 * important detach * it should work * jax autodiff emd * pep8 * no tast same for jax * new independat tests per backedn * freaking pep8 * add tests for gradients * deprecate ot.gpu * worging dist function * working dist * dist done in backedn * not in * remove indexing * change accuacy for jax * first pull backend * projection simplex * projection simplex * projection simplex * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 * add backedn discusion to quickstart guide * projection simplex no ci * projection simplex no ci * projection simplex no ci * pep8 + better doc * proper links * corect doctest * big debug documentation * doctest again * doctest again bis * doctest again ter (last one or i kill myself) * backend test + doc proj simplex * correction test_utils * correction test_utils * correction cumsum * correction flip * correction flip v2 * more debug * more debug * more debug + pep8 * pep8 * argh * proj_simplex * backedn works for sort * proj simplex * jax sucks * update doc * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/quickstart.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update test/test_utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/utils.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update docs/source/readme.rst Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * Update ot/lp/__init__.py Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> * begin comment alex * comment alex part 2 * optimize test gromov * proj_simplex on vectors * add awesome gradient decsnt example on the weights * pep98 of course * proof read example by alex * pep8 again * encoding oos in translation * correct legend * new backend functions for sliced * small indent pb * Optimized backendversion of sliced W * error in sliced W * after master merge * error sliced * error sliced * pep8 * test_sliced pep8 * doctest + precision for sliced * doctest * type win test_backend gather * type win test_backend gather * Update sliced.py change argument of padding pad_width * Update backend.py update redefinition * Update backend.py pep8 * Update backend.py pep 8 again.... * pep8 * build docs * emd2_1D example * refectoring emd_1d and variants * remove unused previous wasserstein_1d * pep8 * upate example * move stuff * tesys should work + implemù random backend * test random generayor functions * correction * better random generation * update sliced * update sliced * proper tests sliced * max sliced * chae file nam * add stuff * example sliced flow and barycenter * correct typo + update readme * exemple sliced flow done * pep8 * solver1d works * pep8 Co-authored-by: Rémi Flamary <remi.flamary@gmail.com> Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/backend.py')
-rw-r--r--ot/backend.py98
1 files changed, 98 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py
index 358297c..d3df44c 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -103,6 +103,8 @@ class Backend():
__name__ = None
__type__ = None
+ rng_ = None
+
def __str__(self):
return self.__name__
@@ -540,6 +542,36 @@ class Backend():
"""
raise NotImplementedError()
+ def seed(self, seed=None):
+ r"""
+ Sets the seed for the random generator.
+
+ This function follows the api from :any:`numpy.random.seed`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html
+ """
+ raise NotImplementedError()
+
+ def rand(self, *size, type_as=None):
+ r"""
+ Generate uniform random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
+ def randn(self, *size, type_as=None):
+ r"""
+ Generate normal Gaussian random numbers.
+
+ This function follows the api from :any:`numpy.random.rand`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html
+ """
+ raise NotImplementedError()
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
r"""
Creates a sparse tensor in COOrdinate format.
@@ -632,6 +664,8 @@ class NumpyBackend(Backend):
__name__ = 'numpy'
__type__ = np.ndarray
+ rng_ = np.random.RandomState()
+
def to_numpy(self, a):
return a
@@ -793,6 +827,16 @@ class NumpyBackend(Backend):
def reshape(self, a, shape):
return np.reshape(a, shape)
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_.seed(seed)
+
+ def rand(self, *size, type_as=None):
+ return self.rng_.rand(*size)
+
+ def randn(self, *size, type_as=None):
+ return self.rng_.randn(*size)
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return coo_matrix((data, (rows, cols)), shape=shape)
@@ -845,6 +889,11 @@ class JaxBackend(Backend):
__name__ = 'jax'
__type__ = jax_type
+ rng_ = None
+
+ def __init__(self):
+ self.rng_ = jax.random.PRNGKey(42)
+
def to_numpy(self, a):
return np.array(a)
@@ -1010,6 +1059,24 @@ class JaxBackend(Backend):
def reshape(self, a, shape):
return jnp.reshape(a, shape)
+ def seed(self, seed=None):
+ if seed is not None:
+ self.rng_ = jax.random.PRNGKey(seed)
+
+ def rand(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.uniform(subkey, shape=size)
+
+ def randn(self, *size, type_as=None):
+ self.rng_, subkey = jax.random.split(self.rng_)
+ if type_as is not None:
+ return jax.random.normal(subkey, shape=size, dtype=type_as.dtype)
+ else:
+ return jax.random.normal(subkey, shape=size)
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
# Currently, JAX does not support sparse matrices
data = self.to_numpy(data)
@@ -1064,8 +1131,13 @@ class TorchBackend(Backend):
__name__ = 'torch'
__type__ = torch_type
+ rng_ = None
+
def __init__(self):
+ self.rng_ = torch.Generator()
+ self.rng_.seed()
+
from torch.autograd import Function
# define a function that takes inputs val and grads
@@ -1102,12 +1174,16 @@ class TorchBackend(Backend):
return res
def zeros(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
if type_as is None:
return torch.zeros(shape)
else:
return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device)
def ones(self, shape, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
if type_as is None:
return torch.ones(shape)
else:
@@ -1120,6 +1196,8 @@ class TorchBackend(Backend):
return torch.arange(start, stop, step, device=type_as.device)
def full(self, shape, fill_value, type_as=None):
+ if isinstance(shape, int):
+ shape = (shape,)
if type_as is None:
return torch.full(shape, fill_value)
else:
@@ -1293,6 +1371,26 @@ class TorchBackend(Backend):
def reshape(self, a, shape):
return torch.reshape(a, shape)
+ def seed(self, seed=None):
+ if isinstance(seed, int):
+ self.rng_.manual_seed(seed)
+ elif isinstance(seed, torch.Generator):
+ self.rng_ = seed
+ else:
+ raise ValueError("Non compatible seed : {}".format(seed))
+
+ def rand(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+ else:
+ return torch.rand(size=size, generator=self.rng_)
+
+ def randn(self, *size, type_as=None):
+ if type_as is not None:
+ return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+ else:
+ return torch.randn(size=size, generator=self.rng_)
+
def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
if type_as is None:
return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape)