diff options
author | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200 |
---|---|---|
committer | Gard Spreemann <gspr@nonempty.org> | 2022-04-27 11:49:23 +0200 |
commit | 35bd2c98b642df78638d7d733bc1a89d873db1de (patch) | |
tree | 6bc637624004713808d3097b95acdccbb9608e52 /ot/utils.py | |
parent | c4753bd3f74139af8380127b66b484bc09b50661 (diff) | |
parent | eccb1386eea52b94b82456d126bd20cbe3198e05 (diff) |
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/ot/utils.py b/ot/utils.py index e6c93c8..a23ce7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature -from .backend import get_backend +from .backend import get_backend, Backend __time_tic_toc = time.time() @@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): def laplacian(x): r"""Compute Laplacian matrix""" - L = np.diag(np.sum(x, axis=0)) - x + nx = get_backend(x) + L = nx.diag(nx.sum(x, axis=0)) - x return L @@ -116,7 +117,7 @@ def proj_simplex(v, z=1): return w -def unif(n): +def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). @@ -124,13 +125,19 @@ def unif(n): ---------- n : int number of bins in the histogram + type_as : array_like + array of the same type of the expected output (numpy/pytorch/jax) Returns ------- - h : np.array (`n`,) + h : array_like (`n`,) histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}` """ - return np.ones((n,)) / n + if type_as is None: + return np.ones((n,)) / n + else: + nx = get_backend(type_as) + return nx.ones((n,), type_as=type_as) / n def clean_zeros(a, b, M): @@ -290,7 +297,8 @@ def cost_normalization(C, norm=None): def dots(*args): r""" dots function for multiple matrix multiply """ - return reduce(np.dot, args) + nx = get_backend(*args) + return reduce(nx.dot, args) def label_normalization(y, start=0): @@ -308,8 +316,9 @@ def label_normalization(y, start=0): y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ + nx = get_backend(y) - diff = np.min(np.unique(y)) - start + diff = nx.min(nx.unique(y)) - start if diff != 0: y -= diff return y @@ -476,6 +485,19 @@ class BaseEstimator(object): arguments (no ``*args`` or ``**kwargs``). """ + nx: Backend = None + + def _get_backend(self, *arrays): + nx = get_backend( + *[input_ for input_ in arrays if input_ is not None] + ) + if nx.__name__ in ("jax", "tf"): + raise TypeError( + """JAX or TF arrays have been received but domain + adaptation does not support those backend.""") + self.nx = nx + return nx + @classmethod def _get_param_names(cls): r"""Get parameter names for the estimator""" |