diff options
author | Nicolas Courty <ncourty@irisa.fr> | 2021-11-02 14:19:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-02 14:19:57 +0100 |
commit | 6775a527f9d3c801f8cdd805d8f205b6a75551b9 (patch) | |
tree | c0ed5a7c297b4003688fec52d46f918ea0086a7d /ot/backend.py | |
parent | a335324d008e8982be61d7ace937815a2bfa98f9 (diff) |
[MRG] Sliced and 1D Wasserstein distances : backend versions (#256)
* add numpy and torch backends
* stat sets on functions
* proper import
* install recent torch on windows
* install recent torch on windows
* now testing all functions in backend
* add jax backend
* cleanup windows
* proper convert for jax backend
* pep8
* try again windows tests
* test jax conversion
* try proper windows tests
* emd function uses backend
* better test partial OT
* proper tests to_numpy and template Backend
* pep8
* pep8 x2
* freaking sinkhorn works with torch
* sinkhorn2 compatible
* working ot.emd2
* important detach
* it should work
* jax autodiff emd
* pep8
* no test same for jax
* new independent tests per backend
* freaking pep8
* add tests for gradients
* deprecate ot.gpu
* working dist function
* working dist
* dist done in backend
* not in
* remove indexing
* change accuracy for jax
* first pull backend
* projection simplex
* projection simplex
* projection simplex
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8
* add backend discussion to quickstart guide
* projection simplex no ci
* projection simplex no ci
* projection simplex no ci
* pep8 + better doc
* proper links
* correct doctest
* big debug documentation
* doctest again
* doctest again bis
* doctest again ter (last one or i kill myself)
* backend test + doc proj simplex
* correction test_utils
* correction test_utils
* correction cumsum
* correction flip
* correction flip v2
* more debug
* more debug
* more debug + pep8
* pep8
* argh
* proj_simplex
* backend works for sort
* proj simplex
* jax sucks
* update doc
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/quickstart.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update test/test_utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/utils.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update docs/source/readme.rst
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* Update ot/lp/__init__.py
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
* begin comment alex
* comment alex part 2
* optimize test gromov
* proj_simplex on vectors
* add awesome gradient descent example on the weights
* pep8 of course
* proof read example by alex
* pep8 again
* encoding oos in translation
* correct legend
* new backend functions for sliced
* small indent pb
* Optimized backendversion of sliced W
* error in sliced W
* after master merge
* error sliced
* error sliced
* pep8
* test_sliced pep8
* doctest + precision for sliced
* doctest
* type win test_backend gather
* type win test_backend gather
* Update sliced.py
change argument of padding pad_width
* Update backend.py
update redefinition
* Update backend.py
pep8
* Update backend.py
pep 8 again....
* pep8
* build docs
* emd2_1D example
* refactoring emd_1d and variants
* remove unused previous wasserstein_1d
* pep8
* update example
* move stuff
* tests should work + implement random backend
* test random generator functions
* correction
* better random generation
* update sliced
* update sliced
* proper tests sliced
* max sliced
* change file name
* add stuff
* example sliced flow and barycenter
* correct typo + update readme
* example sliced flow done
* pep8
* solver1d works
* pep8
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Diffstat (limited to 'ot/backend.py')
-rw-r--r-- | ot/backend.py | 98 |
1 files changed, 98 insertions, 0 deletions
diff --git a/ot/backend.py b/ot/backend.py index 358297c..d3df44c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -103,6 +103,8 @@ class Backend(): __name__ = None __type__ = None + rng_ = None + def __str__(self): return self.__name__ @@ -540,6 +542,36 @@ class Backend(): """ raise NotImplementedError() + def seed(self, seed=None): + r""" + Sets the seed for the random generator. + + This function follows the api from :any:`numpy.random.seed` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.seed.html + """ + raise NotImplementedError() + + def rand(self, *size, type_as=None): + r""" + Generate uniform random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + + def randn(self, *size, type_as=None): + r""" + Generate normal Gaussian random numbers. + + This function follows the api from :any:`numpy.random.rand` + + See: https://numpy.org/doc/stable/reference/generated/numpy.random.rand.html + """ + raise NotImplementedError() + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): r""" Creates a sparse tensor in COOrdinate format. 
@@ -632,6 +664,8 @@ class NumpyBackend(Backend): __name__ = 'numpy' __type__ = np.ndarray + rng_ = np.random.RandomState() + def to_numpy(self, a): return a @@ -793,6 +827,16 @@ class NumpyBackend(Backend): def reshape(self, a, shape): return np.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_.seed(seed) + + def rand(self, *size, type_as=None): + return self.rng_.rand(*size) + + def randn(self, *size, type_as=None): + return self.rng_.randn(*size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return coo_matrix((data, (rows, cols)), shape=shape) @@ -845,6 +889,11 @@ class JaxBackend(Backend): __name__ = 'jax' __type__ = jax_type + rng_ = None + + def __init__(self): + self.rng_ = jax.random.PRNGKey(42) + def to_numpy(self, a): return np.array(a) @@ -1010,6 +1059,24 @@ class JaxBackend(Backend): def reshape(self, a, shape): return jnp.reshape(a, shape) + def seed(self, seed=None): + if seed is not None: + self.rng_ = jax.random.PRNGKey(seed) + + def rand(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.uniform(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.uniform(subkey, shape=size) + + def randn(self, *size, type_as=None): + self.rng_, subkey = jax.random.split(self.rng_) + if type_as is not None: + return jax.random.normal(subkey, shape=size, dtype=type_as.dtype) + else: + return jax.random.normal(subkey, shape=size) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): # Currently, JAX does not support sparse matrices data = self.to_numpy(data) @@ -1064,8 +1131,13 @@ class TorchBackend(Backend): __name__ = 'torch' __type__ = torch_type + rng_ = None + def __init__(self): + self.rng_ = torch.Generator() + self.rng_.seed() + from torch.autograd import Function # define a function that takes inputs val and grads @@ -1102,12 +1174,16 @@ class TorchBackend(Backend): return res def 
zeros(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.zeros(shape) else: return torch.zeros(shape, dtype=type_as.dtype, device=type_as.device) def ones(self, shape, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.ones(shape) else: @@ -1120,6 +1196,8 @@ class TorchBackend(Backend): return torch.arange(start, stop, step, device=type_as.device) def full(self, shape, fill_value, type_as=None): + if isinstance(shape, int): + shape = (shape,) if type_as is None: return torch.full(shape, fill_value) else: @@ -1293,6 +1371,26 @@ class TorchBackend(Backend): def reshape(self, a, shape): return torch.reshape(a, shape) + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_.manual_seed(seed) + elif isinstance(seed, torch.Generator): + self.rng_ = seed + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is not None: + return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device) + else: + return torch.rand(size=size, generator=self.rng_) + + def randn(self, *size, type_as=None): + if type_as is not None: + return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device) + else: + return torch.randn(size=size, generator=self.rng_) + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): if type_as is None: return torch.sparse_coo_tensor(torch.stack([rows, cols]), data, size=shape) |