summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
committerGard Spreemann <gspr@nonempty.org>2022-04-27 11:49:23 +0200
commit35bd2c98b642df78638d7d733bc1a89d873db1de (patch)
tree6bc637624004713808d3097b95acdccbb9608e52 /ot/utils.py
parentc4753bd3f74139af8380127b66b484bc09b50661 (diff)
parenteccb1386eea52b94b82456d126bd20cbe3198e05 (diff)
Merge tag '0.8.2' into dfsg/latest
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py36
1 files changed, 29 insertions, 7 deletions
diff --git a/ot/utils.py b/ot/utils.py
index e6c93c8..a23ce7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend
+from .backend import get_backend, Backend
__time_tic_toc = time.time()
@@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
def laplacian(x):
r"""Compute Laplacian matrix"""
- L = np.diag(np.sum(x, axis=0)) - x
+ nx = get_backend(x)
+ L = nx.diag(nx.sum(x, axis=0)) - x
return L
@@ -116,7 +117,7 @@ def proj_simplex(v, z=1):
return w
-def unif(n):
+def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).
@@ -124,13 +125,19 @@ def unif(n):
----------
n : int
number of bins in the histogram
+ type_as : array_like
+ array of the same type of the expected output (numpy/pytorch/jax)
Returns
-------
- h : np.array (`n`,)
+ h : array_like (`n`,)
histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
"""
- return np.ones((n,)) / n
+ if type_as is None:
+ return np.ones((n,)) / n
+ else:
+ nx = get_backend(type_as)
+ return nx.ones((n,), type_as=type_as) / n
def clean_zeros(a, b, M):
@@ -290,7 +297,8 @@ def cost_normalization(C, norm=None):
def dots(*args):
r""" dots function for multiple matrix multiply """
- return reduce(np.dot, args)
+ nx = get_backend(*args)
+ return reduce(nx.dot, args)
def label_normalization(y, start=0):
@@ -308,8 +316,9 @@ def label_normalization(y, start=0):
y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
+ nx = get_backend(y)
- diff = np.min(np.unique(y)) - start
+ diff = nx.min(nx.unique(y)) - start
if diff != 0:
y -= diff
return y
@@ -476,6 +485,19 @@ class BaseEstimator(object):
arguments (no ``*args`` or ``**kwargs``).
"""
+ nx: Backend = None
+
+ def _get_backend(self, *arrays):
+ nx = get_backend(
+ *[input_ for input_ in arrays if input_ is not None]
+ )
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError(
+ """JAX or TF arrays have been received but domain
+ adaptation does not support those backend.""")
+ self.nx = nx
+ return nx
+
@classmethod
def _get_param_names(cls):
r"""Get parameter names for the estimator"""