author    | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2021-12-09 17:55:12 +0100
committer | GitHub <noreply@github.com> | 2021-12-09 17:55:12 +0100
commit    | f8d871e8c6f15009f559ece6a12eb8d8891c60fb (patch)
tree      | 9aa46b2fcc8046c6cddd8e9159a6f607dcf0e1e9 /ot
parent    | b3dc68feac355fa94c4237f4ecad65edc9f7a7e8 (diff)
[MRG] Tensorflow backend & Benchmarker & Myst_parser (#316)
* First batch of tf methods (to be continued)
* Second batch of methods (yet to debug)
* tensorflow for cpu
* add tf requirement
* pep8 + bug
* small changes
* attempt to solve pymanopt bug with tf2
* attempt #2
* attempt #3
* attempt 4
* docstring
* correct pep8 violations introduced during merge conflict resolution
* attempt 5
* attempt 6
* just a random try
* Revert "just a random try"
This reverts commit 8223e768bfe33635549fb66cca2267514a60ebbf.
* GPU tests for tensorflow
* pep8
* attempt to solve issue with m2r2
* Remove transpose backend method
* first draft of benchmarker (need to correct time measurement)
* prettier bench table
* Bitsize and prettier device methods
* prettified table bench
* Bug corrected (results were mixed up in the final table)
* Better perf counter (for GPU support)
* pep8
* EMD bench
* solve bug if no GPU available
* pep8
* warning that the TensorFlow NumPy API is required, added to the backend.py docstring (see the usage sketch after this list)
* Bug solved in backend docstring
* not covering code which requires a GPU
* Tensorflow gradients manipulation tested
* Number of warmup runs is now customizable
* typo
* Remove some warnings while building docs
* Change prettier_device to device_type in backend
* Correct JAX mistakes that prevented seeing the CPU when a GPU is present
* Attempt to solve JAX bug in case no GPU is found
* Reworked benchmarks order and results storage & clear GPU after usage by benchmark
* Add bench to backend docstring
* better benchmarks
* remove useless stuff
* Better device_type
* Now using myst_parser and solving link issues in the README.md / online docs
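The backend docstring added to `ot/backend.py` in the diff below warns that the TensorFlow backend only works through TensorFlow's NumPy API. Below is a minimal sketch of the documented pattern; the helper `my_function` and the array shapes are illustrative, while the `np_config` activation and `ot.backend.get_backend` come from the diff itself.

```python
# Sketch of the pattern documented in the new backend.py docstring:
# enable TensorFlow's NumPy behaviour, then write backend-agnostic POT code.
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import np_config

np_config.enable_numpy_behavior()  # required before using TensorflowBackend

from ot.backend import get_backend


def my_function(a, b):
    nx = get_backend(a, b)  # infer the backend from the argument types
    return nx.dot(a, b)     # same code path for Numpy, Torch, Jax, Cupy, TF


a = tf.convert_to_tensor(np.random.rand(10, 5), dtype=tf.float32)
b = tf.convert_to_tensor(np.random.rand(5), dtype=tf.float32)
c = my_function(a, b)  # runs on the TensorFlow backend, returns a tf.Tensor
```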
Diffstat (limited to 'ot')
-rw-r--r-- | ot/backend.py      | 580
-rw-r--r-- | ot/bregman.py      |  72
-rw-r--r-- | ot/da.py           |  44
-rw-r--r-- | ot/datasets.py     |   2
-rw-r--r-- | ot/dr.py           |   2
-rw-r--r-- | ot/gromov.py       |   2
-rw-r--r-- | ot/lp/solver_1d.py |   4
-rw-r--r-- | ot/plot.py         |   4
8 files changed, 645 insertions, 65 deletions
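For the benchmarker, the diff below adds a private `_bench(callable, *args, n_runs=1, warmup_runs=1)` hook to every backend: it converts the NumPy arguments for each dtype/device in `__type_list__`, runs untimed warmup calls, then averages timed calls into a dict keyed by (backend name, device, bitsize). The following driver loop is a hedged sketch, not the PR's benchmark script (which lives outside the `ot` directory and is not in this diff); the choice of `ot.sinkhorn` as the measured callable and the run counts are assumptions.

```python
# Hypothetical driver for the per-backend _bench hooks added in this diff.
from functools import partial

import numpy as np
import ot
from ot.backend import get_backend_list

try:  # the TensorFlow backend needs its NumPy behaviour enabled (see above)
    from tensorflow.python.ops.numpy_ops import np_config
    np_config.enable_numpy_behavior()
except ImportError:
    pass

rng = np.random.RandomState(0)
x, y = rng.randn(100, 2), rng.randn(100, 2)
a, b, M = ot.unif(100), ot.unif(100), ot.dist(x, y)

results = {}
for nx in get_backend_list():
    # _bench converts the numpy args with from_numpy for every dtype/device
    # in nx.__type_list__, does warmup_runs untimed calls, then averages
    # n_runs timed calls (GPU backends use device events for timing).
    res = nx._bench(partial(ot.sinkhorn, reg=1.0), a, b, M,
                    n_runs=10, warmup_runs=5)
    results.update(res)

# Keys look like ("Numpy"/"Pytorch"/..., "CPU"/"GPU", 32/64); values are
# mean seconds per run.
for key, mean_time in sorted(results.items()):
    print(key, f"{mean_time:.4f} s/run")
```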
diff --git a/ot/backend.py b/ot/backend.py index 1630ac4..58b652b 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -3,7 +3,7 @@ Multi-lib backend for POT The goal is to write backend-agnostic code. Whether you're using Numpy, PyTorch, -Jax, or Cupy, POT code should work nonetheless. +Jax, Cupy, or Tensorflow, POT code should work nonetheless. To achieve that, POT provides backend classes which implements functions in their respective backend imitating Numpy API. As a convention, we use nx instead of np to refer to the backend. @@ -17,6 +17,68 @@ Examples ... nx = get_backend(a, b) # infer the backend from the arguments ... c = nx.dot(a, b) # now use the backend to do any calculation ... return c + +.. warning:: + Tensorflow only works with the Numpy API. To activate it, please run the following: + + .. code-block:: + + from tensorflow.python.ops.numpy_ops import np_config + np_config.enable_numpy_behavior() + +Performance +-------- + +- CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz +- GPU: Tesla V100-SXM2-32GB +- Date of the benchmark: December 8th, 2021 +- Commit of benchmark: PR #316, https://github.com/PythonOT/POT/pull/316 + +.. raw:: html + + <style> + #perftable { + width: 100%; + margin-bottom: 1em; + } + + #perftable table{ + border-collapse: collapse; + table-layout: fixed; + width: 100%; + } + + #perftable th, #perftable td { + border: 1px solid #ddd; + padding: 8px; + font-size: smaller; + } + </style> + + <div id="perftable"> + <table> + <tr><th align="center" colspan="8">Sinkhorn Knopp - Averaged on 100 runs</th></tr> + <tr><th align="center">Bitsize</th><th align="center" colspan="7">32 bits</th></tr> + <tr><th align="center">Device</th><th align="center" colspan="3.0"">CPU</th><th align="center" colspan="4.0">GPU</tr> + <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr> + <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0022</td><td align="center">0.0151</td><td align="center">0.0095</td><td align="center">0.0193</td><td align="center">0.0051</td><td align="center">0.0293</td></tr> + <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0097</td><td align="center">0.0057</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0173</td></tr> + <tr><td align="center">500</td><td align="center">0.0009</td><td align="center">0.0016</td><td align="center">0.0110</td><td align="center">0.0058</td><td align="center">0.0115</td><td align="center">0.0029</td><td align="center">0.0166</td></tr> + <tr><td align="center">1000</td><td align="center">0.0021</td><td align="center">0.0021</td><td align="center">0.0145</td><td align="center">0.0056</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0168</td></tr> + <tr><td align="center">2000</td><td align="center">0.0069</td><td align="center">0.0043</td><td align="center">0.0278</td><td align="center">0.0059</td><td align="center">0.0118</td><td align="center">0.0030</td><td align="center">0.0165</td></tr> + <tr><td align="center">5000</td><td align="center">0.0707</td><td align="center">0.0314</td><td align="center">0.1395</td><td align="center">0.0074</td><td align="center">0.0125</td><td align="center">0.0035</td><td align="center">0.0198</td></tr> + <tr><td colspan="8"> </td></tr> 
+ <tr><th align="center">Bitsize</th><th align="center" colspan="7">64 bits</th></tr> + <tr><th align="center">Device</th><th align="center" colspan="3.0"">CPU</th><th align="center" colspan="4.0">GPU</tr> + <tr><th align="center">Sample size</th><th align="center">Numpy</th><th align="center">Pytorch</th><th align="center">Tensorflow</th><th align="center">Cupy</th><th align="center">Jax</th><th align="center">Pytorch</th><th align="center">Tensorflow</th></tr> + <tr><td align="center">50</td><td align="center">0.0008</td><td align="center">0.0020</td><td align="center">0.0154</td><td align="center">0.0093</td><td align="center">0.0191</td><td align="center">0.0051</td><td align="center">0.0328</td></tr> + <tr><td align="center">100</td><td align="center">0.0005</td><td align="center">0.0013</td><td align="center">0.0094</td><td align="center">0.0056</td><td align="center">0.0114</td><td align="center">0.0029</td><td align="center">0.0169</td></tr> + <tr><td align="center">500</td><td align="center">0.0013</td><td align="center">0.0017</td><td align="center">0.0120</td><td align="center">0.0059</td><td align="center">0.0116</td><td align="center">0.0029</td><td align="center">0.0168</td></tr> + <tr><td align="center">1000</td><td align="center">0.0034</td><td align="center">0.0027</td><td align="center">0.0177</td><td align="center">0.0058</td><td align="center">0.0118</td><td align="center">0.0029</td><td align="center">0.0167</td></tr> + <tr><td align="center">2000</td><td align="center">0.0146</td><td align="center">0.0075</td><td align="center">0.0436</td><td align="center">0.0059</td><td align="center">0.0120</td><td align="center">0.0029</td><td align="center">0.0165</td></tr> + <tr><td align="center">5000</td><td align="center">0.1467</td><td align="center">0.0568</td><td align="center">0.2468</td><td align="center">0.0077</td><td align="center">0.0146</td><td align="center">0.0045</td><td align="center">0.0204</td></tr> + </table> + </div> """ # Author: Remi Flamary <remi.flamary@polytechnique.edu> @@ -27,6 +89,8 @@ Examples import numpy as np import scipy.special as scipy from scipy.sparse import issparse, coo_matrix, csr_matrix +import warnings +import time try: import torch @@ -39,6 +103,7 @@ try: import jax import jax.numpy as jnp import jax.scipy.special as jscipy + from jax.lib import xla_bridge jax_type = jax.numpy.ndarray except ImportError: jax = False @@ -52,6 +117,15 @@ except ImportError: cp = False cp_type = float +try: + import tensorflow as tf + import tensorflow.experimental.numpy as tnp + tf_type = tf.Tensor +except ImportError: + tf = False + tf_type = float + + str_type_error = "All array should be from the same type/backend. Current types are : {}" @@ -65,9 +139,12 @@ def get_backend_list(): if jax: lst.append(JaxBackend()) - if cp: + if cp: # pragma: no cover lst.append(CupyBackend()) + if tf: + lst.append(TensorflowBackend()) + return lst @@ -89,8 +166,10 @@ def get_backend(*args): return TorchBackend() elif isinstance(args[0], jax_type): return JaxBackend() - elif isinstance(args[0], cp_type): + elif isinstance(args[0], cp_type): # pragma: no cover return CupyBackend() + elif isinstance(args[0], tf_type): + return TensorflowBackend() else: raise ValueError("Unknown type of non implemented backend.") @@ -108,7 +187,7 @@ class Backend(): """ Backend abstract class. 
Implementations: :py:class:`JaxBackend`, :py:class:`NumpyBackend`, :py:class:`TorchBackend`, - :py:class:`CupyBackend` + :py:class:`CupyBackend`, :py:class:`TensorflowBackend` - The `__name__` class attribute refers to the name of the backend. - The `__type__` class attribute refers to the data structure used by the backend. @@ -679,6 +758,34 @@ class Backend(): """ raise NotImplementedError() + def squeeze(self, a, axis=None): + r""" + Remove axes of length one from a. + + This function follows the api from :any:`numpy.squeeze`. + + See: https://numpy.org/doc/stable/reference/generated/numpy.squeeze.html + """ + raise NotImplementedError() + + def bitsize(self, type_as): + r""" + Gives the number of bits used by the data type of the given tensor. + """ + raise NotImplementedError() + + def device_type(self, type_as): + r""" + Returns CPU or GPU depending on the device where the given tensor is located. + """ + raise NotImplementedError() + + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + r""" + Executes a benchmark of the given callable with the given arguments. + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -916,6 +1023,29 @@ class NumpyBackend(Backend): # numpy has implicit type conversion so we automatically validate the test pass + def squeeze(self, a, axis=None): + return np.squeeze(a, axis=axis) + + def bitsize(self, type_as): + return type_as.itemsize * 8 + + def device_type(self, type_as): + return "CPU" + + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + results = dict() + for type_as in self.__type_list__: + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + for _ in range(warmup_runs): + callable(*inputs) + t0 = time.perf_counter() + for _ in range(n_runs): + callable(*inputs) + t1 = time.perf_counter() + key = ("Numpy", self.device_type(type_as), self.bitsize(type_as)) + results[key] = (t1 - t0) / n_runs + return results + class JaxBackend(Backend): """ @@ -934,9 +1064,16 @@ class JaxBackend(Backend): def __init__(self): self.rng_ = jax.random.PRNGKey(42) - for d in jax.devices(): - self.__type_list__ = [jax.device_put(jnp.array(1, dtype=jnp.float32), d), - jax.device_put(jnp.array(1, dtype=jnp.float64), d)] + self.__type_list__ = [] + # available_devices = jax.devices("cpu") + available_devices = [] + if xla_bridge.get_backend().platform == "gpu": + available_devices += jax.devices("gpu") + for d in available_devices: + self.__type_list__ += [ + jax.device_put(jnp.array(1, dtype=jnp.float32), d), + jax.device_put(jnp.array(1, dtype=jnp.float64), d) + ] def to_numpy(self, a): return np.array(a) @@ -1176,6 +1313,32 @@ class JaxBackend(Backend): assert a_dtype == b_dtype, "Dtype discrepancy" assert a_device == b_device, f"Device discrepancy. 
First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + def squeeze(self, a, axis=None): + return jnp.squeeze(a, axis=axis) + + def bitsize(self, type_as): + return type_as.dtype.itemsize * 8 + + def device_type(self, type_as): + return self.dtype_device(type_as)[1].platform.upper() + + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + results = dict() + + for type_as in self.__type_list__: + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + for _ in range(warmup_runs): + a = callable(*inputs) + a.block_until_ready() + t0 = time.perf_counter() + for _ in range(n_runs): + a = callable(*inputs) + a.block_until_ready() + t1 = time.perf_counter() + key = ("Jax", self.device_type(type_as), self.bitsize(type_as)) + results[key] = (t1 - t0) / n_runs + return results + class TorchBackend(Backend): """ @@ -1515,6 +1678,46 @@ class TorchBackend(Backend): assert a_dtype == b_dtype, "Dtype discrepancy" assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + def squeeze(self, a, axis=None): + if axis is None: + return torch.squeeze(a) + else: + return torch.squeeze(a, dim=axis) + + def bitsize(self, type_as): + return torch.finfo(type_as.dtype).bits + + def device_type(self, type_as): + return type_as.device.type.replace("cuda", "gpu").upper() + + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + results = dict() + for type_as in self.__type_list__: + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + for _ in range(warmup_runs): + callable(*inputs) + if self.device_type(type_as) == "GPU": # pragma: no cover + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + else: + start = time.perf_counter() + for _ in range(n_runs): + callable(*inputs) + if self.device_type(type_as) == "GPU": # pragma: no cover + end.record() + torch.cuda.synchronize() + duration = start.elapsed_time(end) / 1000. + else: + end = time.perf_counter() + duration = end - start + key = ("Pytorch", self.device_type(type_as), self.bitsize(type_as)) + results[key] = duration / n_runs + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return results + class CupyBackend(Backend): # pragma: no cover """ @@ -1798,3 +2001,366 @@ class CupyBackend(Backend): # pragma: no cover # cupy has implicit type conversion so # we automatically validate the test for type assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + + def squeeze(self, a, axis=None): + return cp.squeeze(a, axis=axis) + + def bitsize(self, type_as): + return type_as.itemsize * 8 + + def device_type(self, type_as): + return "GPU" + + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + + results = dict() + for type_as in self.__type_list__: + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + start_gpu = cp.cuda.Event() + end_gpu = cp.cuda.Event() + for _ in range(warmup_runs): + callable(*inputs) + start_gpu.synchronize() + start_gpu.record() + for _ in range(n_runs): + callable(*inputs) + end_gpu.record() + end_gpu.synchronize() + key = ("Cupy", self.device_type(type_as), self.bitsize(type_as)) + t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) / 1000. 
+ results[key] = t_gpu / n_runs + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + return results + + +class TensorflowBackend(Backend): + + __name__ = "tf" + __type__ = tf_type + __type_list__ = None + + rng_ = None + + def __init__(self): + self.seed(None) + + self.__type_list__ = [ + tf.convert_to_tensor([1], dtype=tf.float32), + tf.convert_to_tensor([1], dtype=tf.float64) + ] + + tmp = self.randn(15, 10) + try: + tmp.reshape((150, 1)) + except AttributeError: + warnings.warn( + "To use TensorflowBackend, you need to activate the tensorflow " + "numpy API. You can activate it by running: \n" + "from tensorflow.python.ops.numpy_ops import np_config\n" + "np_config.enable_numpy_behavior()" + ) + + def to_numpy(self, a): + return a.numpy() + + def from_numpy(self, a, type_as=None): + if not isinstance(a, self.__type__): + if type_as is None: + return tf.convert_to_tensor(a) + else: + return tf.convert_to_tensor(a, dtype=type_as.dtype) + else: + if type_as is None: + return a + else: + return tf.cast(a, dtype=type_as.dtype) + + def set_gradients(self, val, inputs, grads): + @tf.custom_gradient + def tmp(input): + def grad(upstream): + return grads + return val, grad + return tmp(inputs) + + def zeros(self, shape, type_as=None): + if type_as is None: + return tnp.zeros(shape) + else: + return tnp.zeros(shape, dtype=type_as.dtype) + + def ones(self, shape, type_as=None): + if type_as is None: + return tnp.ones(shape) + else: + return tnp.ones(shape, dtype=type_as.dtype) + + def arange(self, stop, start=0, step=1, type_as=None): + return tnp.arange(start, stop, step) + + def full(self, shape, fill_value, type_as=None): + if type_as is None: + return tnp.full(shape, fill_value) + else: + return tnp.full(shape, fill_value, dtype=type_as.dtype) + + def eye(self, N, M=None, type_as=None): + if type_as is None: + return tnp.eye(N, M) + else: + return tnp.eye(N, M, dtype=type_as.dtype) + + def sum(self, a, axis=None, keepdims=False): + return tnp.sum(a, axis, keepdims=keepdims) + + def cumsum(self, a, axis=None): + return tnp.cumsum(a, axis) + + def max(self, a, axis=None, keepdims=False): + return tnp.max(a, axis, keepdims=keepdims) + + def min(self, a, axis=None, keepdims=False): + return tnp.min(a, axis, keepdims=keepdims) + + def maximum(self, a, b): + return tnp.maximum(a, b) + + def minimum(self, a, b): + return tnp.minimum(a, b) + + def dot(self, a, b): + if len(b.shape) == 1: + if len(a.shape) == 1: + # inner product + return tf.reduce_sum(tf.multiply(a, b)) + else: + # matrix vector + return tf.linalg.matvec(a, b) + else: + if len(a.shape) == 1: + return tf.linalg.matvec(b.T, a.T).T + else: + return tf.matmul(a, b) + + def abs(self, a): + return tnp.abs(a) + + def exp(self, a): + return tnp.exp(a) + + def log(self, a): + return tnp.log(a) + + def sqrt(self, a): + return tnp.sqrt(a) + + def power(self, a, exponents): + return tnp.power(a, exponents) + + def norm(self, a): + return tf.math.reduce_euclidean_norm(a) + + def any(self, a): + return tnp.any(a) + + def isnan(self, a): + return tnp.isnan(a) + + def isinf(self, a): + return tnp.isinf(a) + + def einsum(self, subscripts, *operands): + return tnp.einsum(subscripts, *operands) + + def sort(self, a, axis=-1): + return tnp.sort(a, axis) + + def argsort(self, a, axis=-1): + return tnp.argsort(a, axis) + + def searchsorted(self, a, v, side='left'): + return tf.searchsorted(a, v, side=side) + + def flip(self, a, axis=None): + return tnp.flip(a, axis) + + def outer(self, a, b): + return tnp.outer(a, b) + + def clip(self, a, 
a_min, a_max): + return tnp.clip(a, a_min, a_max) + + def repeat(self, a, repeats, axis=None): + return tnp.repeat(a, repeats, axis) + + def take_along_axis(self, arr, indices, axis): + return tnp.take_along_axis(arr, indices, axis) + + def concatenate(self, arrays, axis=0): + return tnp.concatenate(arrays, axis) + + def zero_pad(self, a, pad_width): + return tnp.pad(a, pad_width, mode="constant") + + def argmax(self, a, axis=None): + return tnp.argmax(a, axis=axis) + + def mean(self, a, axis=None): + return tnp.mean(a, axis=axis) + + def std(self, a, axis=None): + return tnp.std(a, axis=axis) + + def linspace(self, start, stop, num): + return tnp.linspace(start, stop, num) + + def meshgrid(self, a, b): + return tnp.meshgrid(a, b) + + def diag(self, a, k=0): + return tnp.diag(a, k) + + def unique(self, a): + return tf.sort(tf.unique(tf.reshape(a, [-1]))[0]) + + def logsumexp(self, a, axis=None): + return tf.math.reduce_logsumexp(a, axis=axis) + + def stack(self, arrays, axis=0): + return tnp.stack(arrays, axis) + + def reshape(self, a, shape): + return tnp.reshape(a, shape) + + def seed(self, seed=None): + if isinstance(seed, int): + self.rng_ = tf.random.Generator.from_seed(seed) + elif isinstance(seed, tf.random.Generator): + self.rng_ = seed + elif seed is None: + self.rng_ = tf.random.Generator.from_non_deterministic_state() + else: + raise ValueError("Non compatible seed : {}".format(seed)) + + def rand(self, *size, type_as=None): + if type_as is None: + return self.rng_.uniform(size, minval=0., maxval=1.) + else: + return self.rng_.uniform( + size, minval=0., maxval=1., dtype=type_as.dtype + ) + + def randn(self, *size, type_as=None): + if type_as is None: + return self.rng_.normal(size) + else: + return self.rng_.normal(size, dtype=type_as.dtype) + + def _convert_to_index_for_coo(self, tensor): + if isinstance(tensor, self.__type__): + return int(self.max(tensor)) + 1 + else: + return int(np.max(tensor)) + 1 + + def coo_matrix(self, data, rows, cols, shape=None, type_as=None): + if shape is None: + shape = ( + self._convert_to_index_for_coo(rows), + self._convert_to_index_for_coo(cols) + ) + if type_as is not None: + data = self.from_numpy(data, type_as=type_as) + + sparse_tensor = tf.sparse.SparseTensor( + indices=tnp.stack([rows, cols]).T, + values=data, + dense_shape=shape + ) + # if type_as is not None: + # sparse_tensor = self.from_numpy(sparse_tensor, type_as=type_as) + # SparseTensor are not subscriptable so we use dense tensors + return self.todense(sparse_tensor) + + def issparse(self, a): + return isinstance(a, tf.sparse.SparseTensor) + + def tocsr(self, a): + return a + + def eliminate_zeros(self, a, threshold=0.): + if self.issparse(a): + values = a.values + if threshold > 0: + mask = self.abs(values) <= threshold + else: + mask = values == 0 + return tf.sparse.retain(a, ~mask) + else: + if threshold > 0: + a = tnp.where(self.abs(a) > threshold, a, 0.) 
+ return a + + def todense(self, a): + if self.issparse(a): + return tf.sparse.to_dense(tf.sparse.reorder(a)) + else: + return a + + def where(self, condition, x, y): + return tnp.where(condition, x, y) + + def copy(self, a): + return tf.identity(a) + + def allclose(self, a, b, rtol=1e-05, atol=1e-08, equal_nan=False): + return tnp.allclose( + a, b, rtol=rtol, atol=atol, equal_nan=equal_nan + ) + + def dtype_device(self, a): + return a.dtype, a.device.split("device:")[1] + + def assert_same_dtype_device(self, a, b): + a_dtype, a_device = self.dtype_device(a) + b_dtype, b_device = self.dtype_device(b) + + assert a_dtype == b_dtype, "Dtype discrepancy" + assert a_device == b_device, f"Device discrepancy. First input is on {str(a_device)}, whereas second input is on {str(b_device)}" + + def squeeze(self, a, axis=None): + return tnp.squeeze(a, axis=axis) + + def bitsize(self, type_as): + return type_as.dtype.size * 8 + + def device_type(self, type_as): + return self.dtype_device(type_as)[1].split(":")[0] + + def _bench(self, callable, *args, n_runs=1, warmup_runs=1): + results = dict() + device_contexts = [tf.device("/CPU:0")] + if len(tf.config.list_physical_devices('GPU')) > 0: # pragma: no cover + device_contexts.append(tf.device("/GPU:0")) + + for device_context in device_contexts: + with device_context: + for type_as in self.__type_list__: + inputs = [self.from_numpy(arg, type_as=type_as) for arg in args] + for _ in range(warmup_runs): + callable(*inputs) + t0 = time.perf_counter() + for _ in range(n_runs): + res = callable(*inputs) + _ = res.numpy() + t1 = time.perf_counter() + key = ( + "Tensorflow", + self.device_type(inputs[0]), + self.bitsize(type_as) + ) + results[key] = (t1 - t0) / n_runs + + return results diff --git a/ot/bregman.py b/ot/bregman.py index cce52e2..fc20175 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -830,9 +830,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) - if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received. Greenkhorn is not " - "compatible with JAX") + if nx.__name__ in ("jax", "tf"): + raise TypeError("JAX or TF arrays have been received. Greenkhorn is not " + "compatible with neither JAX nor TF") if len(a) == 0: a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] @@ -865,20 +865,20 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, if m_viol_1 > m_viol_2: old_u = u[i_1] - new_u = a[i_1] / (K[i_1, :].dot(v)) + new_u = a[i_1] / nx.dot(K[i_1, :], v) G[i_1, :] = new_u * K[i_1, :] * v - viol[i_1] = new_u * K[i_1, :].dot(v) - a[i_1] + viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1] viol_2 += (K[i_1, :].T * (new_u - old_u) * v) u[i_1] = new_u else: old_v = v[i_2] - new_v = b[i_2] / (K[:, i_2].T.dot(u)) + new_v = b[i_2] / nx.dot(K[:, i_2].T, u) G[:, i_2] = u * K[:, i_2] * new_v # aviol = (G@one_m - a) # aviol_2 = (G.T@one_n - b) viol += (-old_v + new_v) * K[:, i_2] * u - viol_2[i_2] = new_v * K[:, i_2].dot(u) - b[i_2] + viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2] v[i_2] = new_v if stopThr_val <= stopThr: @@ -1550,9 +1550,11 @@ def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, nx = get_backend(A, M) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and tf. 
Use numpy or torch arrays instead." + ) if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists @@ -1886,9 +1888,11 @@ def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, dim, n_hists = A.shape nx = get_backend(A, M) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) if weights is None: weights = nx.ones(n_hists, type_as=A) / n_hists @@ -2043,7 +2047,7 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, log = {'err': []} bar = nx.ones(A.shape[1:], type_as=A) - bar /= bar.sum() + bar /= nx.sum(bar) U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) err = 1 @@ -2069,9 +2073,11 @@ def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) - bar = nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + bar = nx.exp( + nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) + ) if ii % 10 == 9: - err = (V * KU).std(axis=0).sum() + err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: log['err'].append(err) @@ -2106,9 +2112,11 @@ def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, A = list_to_array(A) nx = get_backend(A) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) n_hists, width, height = A.shape @@ -2298,13 +2306,15 @@ def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) - bar = c * nx.exp((weights[:, None, None] * nx.log(KU + stabThr)).sum(axis=0)) + bar = c * nx.exp( + nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) + ) for _ in range(10): - c = (c * bar / convol_imgs(c[None]).squeeze()) ** 0.5 + c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5 if ii % 10 == 9: - err = (V * KU).std(axis=0).sum() + err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: log['err'].append(err) @@ -2340,9 +2350,11 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) - if nx.__name__ == "jax": - raise NotImplementedError("Log-domain functions are not yet implemented" - " for Jax. Use numpy or torch arrays instead.") + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." 
+ ) if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: @@ -2382,7 +2394,7 @@ def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10 c = 0.5 * (c + log_bar - convol_img(c)) if ii % 10 == 9: - err = nx.exp(G + log_KU).std(axis=0).sum() + err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0)) # log and verbose print if log: log['err'].append(err) @@ -3312,9 +3324,9 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) - if nx.__name__ == "jax": - raise TypeError("JAX arrays have been received but screenkhorn is not " - "compatible with JAX.") + if nx.__name__ in ("jax", "tf"): + raise TypeError("JAX or TF arrays have been received but screenkhorn is not " + "compatible with neither JAX nor TF.") ns, nt = M.shape @@ -3328,7 +3340,7 @@ def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, K = nx.exp(-M / reg) def projection(u, epsilon): - u[u <= epsilon] = epsilon + u = nx.maximum(u, epsilon) return u # ----------------------------------------------------------------------------------------------------------------# @@ -906,7 +906,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al def distribution_estimation_uniform(X): - """estimates a uniform distribution from an array of samples :math:`\mathbf{X}` + r"""estimates a uniform distribution from an array of samples :math:`\mathbf{X}` Parameters ---------- @@ -950,7 +950,7 @@ class BaseTransport(BaseEstimator): """ def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1010,7 +1010,7 @@ class BaseTransport(BaseEstimator): return self def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` and transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` @@ -1038,7 +1038,7 @@ class BaseTransport(BaseEstimator): return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt) def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1105,7 +1105,7 @@ class BaseTransport(BaseEstimator): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in + r"""Propagate source labels :math:`\mathbf{y_s}` to obtain estimated target labels as in :ref:`[27] <references-basetransport-transform-labels>`. 
Parameters @@ -1152,7 +1152,7 @@ class BaseTransport(BaseEstimator): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` + r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1218,7 +1218,7 @@ class BaseTransport(BaseEstimator): return transp_Xt def inverse_transform_labels(self, yt=None): - """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels :math:`\mathbf{y_s}` Parameters @@ -1307,7 +1307,7 @@ class LinearTransport(BaseTransport): self.distribution_estimation = distribution_estimation def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1354,7 +1354,7 @@ class LinearTransport(BaseTransport): return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -1387,7 +1387,7 @@ class LinearTransport(BaseTransport): def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` + r"""Transports target samples :math:`\mathbf{X_t}` onto source samples :math:`\mathbf{X_s}` Parameters ---------- @@ -1493,7 +1493,7 @@ class SinkhornTransport(BaseTransport): self.out_of_sample_map = out_of_sample_map def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1592,7 +1592,7 @@ class EMDTransport(BaseTransport): self.max_iter = max_iter def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1711,7 +1711,7 @@ class SinkhornLpl1Transport(BaseTransport): self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1839,7 +1839,7 @@ class EMDLaplaceTransport(BaseTransport): self.out_of_sample_map = out_of_sample_map def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -1962,7 +1962,7 @@ class SinkhornL1l2Transport(BaseTransport): self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, 
\mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -2088,7 +2088,7 @@ class MappingTransport(BaseEstimator): self.verbose2 = verbose2 def fit(self, Xs=None, ys=None, Xt=None, yt=None): - """Builds an optimal coupling and estimates the associated mapping + r"""Builds an optimal coupling and estimates the associated mapping from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` @@ -2146,7 +2146,7 @@ class MappingTransport(BaseEstimator): return self def transform(self, Xs): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2261,7 +2261,7 @@ class UnbalancedSinkhornTransport(BaseTransport): self.limit_max = limit_max def fit(self, Xs, ys=None, Xt=None, yt=None): - """Build a coupling matrix from source and target sets of samples + r"""Build a coupling matrix from source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -2373,7 +2373,7 @@ class JCPOTTransport(BaseTransport): self.out_of_sample_map = out_of_sample_map def fit(self, Xs, ys=None, Xt=None, yt=None): - """Building coupling matrices from a list of source and target sets of samples + r"""Building coupling matrices from a list of source and target sets of samples :math:`(\mathbf{X_s}, \mathbf{y_s})` and :math:`(\mathbf{X_t}, \mathbf{y_t})` Parameters @@ -2419,7 +2419,7 @@ class JCPOTTransport(BaseTransport): return self def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): - """Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` + r"""Transports source samples :math:`\mathbf{X_s}` onto target ones :math:`\mathbf{X_t}` Parameters ---------- @@ -2491,7 +2491,7 @@ class JCPOTTransport(BaseTransport): return transp_Xs def transform_labels(self, ys=None): - """Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in + r"""Propagate source labels :math:`\mathbf{y_s}` to obtain target labels as in :ref:`[27] <references-jcpottransport-transform-labels>` Parameters @@ -2542,7 +2542,7 @@ class JCPOTTransport(BaseTransport): return yt.T def inverse_transform_labels(self, yt=None): - """Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels + r"""Propagate target labels :math:`\mathbf{y_t}` to obtain estimated source labels :math:`\mathbf{y_s}` Parameters diff --git a/ot/datasets.py b/ot/datasets.py index ad6390c..a839074 100644 --- a/ot/datasets.py +++ b/ot/datasets.py @@ -41,7 +41,7 @@ def get_1D_gauss(n, m, sigma): def make_2D_samples_gauss(n, m, sigma, random_state=None): - """Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)` + r"""Return `n` samples drawn from 2D gaussian :math:`\mathcal{N}(m, \sigma)` Parameters ---------- @@ -16,6 +16,7 @@ Dimension reduction with OT from scipy import linalg import autograd.numpy as np +from pymanopt.function import Autograd from pymanopt.manifolds import Stiefel from pymanopt import Problem from pymanopt.solvers import SteepestDescent, TrustRegions @@ -181,6 +182,7 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no else: regmean = np.ones((len(xc), len(xc))) + @Autograd def cost(P): # wda loss loss_b = 0 diff --git a/ot/gromov.py b/ot/gromov.py index dc95c74..6544260 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -947,7 +947,7 @@ def 
pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, index[0] = generator.choice(len_p, size=1, p=nx.to_numpy(p))
T_index0 = nx.reshape(nx.todense(T[index[0], :]), (-1,))
index[1] = generator.choice(
- len_q, size=1, p=nx.to_numpy(T_index0 / T_index0.sum())
+ len_q, size=1, p=nx.to_numpy(T_index0 / nx.sum(T_index0))
)
if alpha == 1:
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 8b4d0c3..43763a9 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -100,11 +100,11 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ m = v_values.shape[0] if u_weights is None: - u_weights = nx.full(u_values.shape, 1. / n) + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) if v_weights is None: - v_weights = nx.full(v_values.shape, 1. / m) + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) @@ -18,7 +18,7 @@ from matplotlib import gridspec def plot1D_mat(a, b, M, title=''): - """ Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution + r""" Plot matrix :math:`\mathbf{M}` with the source and target 1D distribution Creates a subplot with the source distribution :math:`\mathbf{a}` on the left and target distribution :math:`\mathbf{b}` on the top. The matrix :math:`\mathbf{M}` is shown in between. @@ -61,7 +61,7 @@ def plot1D_mat(a, b, M, title=''): def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs): - """ Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values + r""" Plot matrix :math:`\mathbf{G}` in 2D with lines using alpha values Plot lines between source and target 2D samples with a color proportional to the value of the matrix :math:`\mathbf{G}` between samples. |
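Finally, the `ot/lp/solver_1d.py` hunk above makes the implicit uniform weights follow the dtype and device of the input values (`type_as=u_values`). A small illustration with PyTorch tensors follows; it assumes `ot.wasserstein_1d` is the public entry point for the patched function and that PyTorch is installed.

```python
# Before this change, the default weights were created without type_as and
# could end up in a different dtype than float32 inputs; now they match.
import torch
import ot

u = torch.rand(50, dtype=torch.float32)
v = torch.rand(80, dtype=torch.float32)

# No explicit weights: the solver builds nx.full(..., type_as=u) internally,
# so everything stays float32 on the same device as the inputs.
loss = ot.wasserstein_1d(u, v, p=2)
print(loss.dtype)  # expected: torch.float32
```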