summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py26
1 files changed, 21 insertions, 5 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 725ca00..a23ce7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend
+from .backend import get_backend, Backend
__time_tic_toc = time.time()
@@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
def laplacian(x):
r"""Compute Laplacian matrix"""
- L = np.diag(np.sum(x, axis=0)) - x
+ nx = get_backend(x)
+ L = nx.diag(nx.sum(x, axis=0)) - x
return L
@@ -136,7 +137,7 @@ def unif(n, type_as=None):
return np.ones((n,)) / n
else:
nx = get_backend(type_as)
- return nx.ones((n,)) / n
+ return nx.ones((n,), type_as=type_as) / n
def clean_zeros(a, b, M):
@@ -296,7 +297,8 @@ def cost_normalization(C, norm=None):
def dots(*args):
r""" dots function for multiple matrix multiply """
- return reduce(np.dot, args)
+ nx = get_backend(*args)
+ return reduce(nx.dot, args)
def label_normalization(y, start=0):
@@ -314,8 +316,9 @@ def label_normalization(y, start=0):
y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
+ nx = get_backend(y)
- diff = np.min(np.unique(y)) - start
+ diff = nx.min(nx.unique(y)) - start
if diff != 0:
y -= diff
return y
@@ -482,6 +485,19 @@ class BaseEstimator(object):
arguments (no ``*args`` or ``**kwargs``).
"""
+ nx: Backend = None
+
+ def _get_backend(self, *arrays):
+ nx = get_backend(
+ *[input_ for input_ in arrays if input_ is not None]
+ )
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError(
+ """JAX or TF arrays have been received but domain
+ adaptation does not support those backend.""")
+ self.nx = nx
+ return nx
+
@classmethod
def _get_param_names(cls):
r"""Get parameter names for the estimator"""