diff options
author | Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com> | 2022-03-24 10:53:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-03-24 10:53:47 +0100 |
commit | 767171593f2a98a26b9a39bf110a45085e3b982e (patch) | |
tree | 4eb4bcc657efc53a65c3fb4439bd0e0e106b6745 /ot/utils.py | |
parent | 9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff) |
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft
* Add matrix inverse and square root to backend
* Eigen decomposition for older versions of pytorch (1.8.1 and older)
* Corrected eigen decomposition for pytorch 1.8.1 and older
* Spectral theorem is a thing
* Optimization
* small optimization
* More functions converted
* pep8
* remove a warning and prepare torch meshgrid for future torch release (which will change default indexing)
* dots and pep8
* Meshgrid corrected for older version and prepared for future versions changes
* New backend functions
* Base transport
* LinearTransport
* All transport classes + pep8
* PR added to release file
* Jcpot barycenter test
* unbalanced with backend
* pep8
* bug fix
* test of domain adaptation with backends
* solve bug for tic toc & macos
* solving scipy deprecation warning
* solving scipy deprecation warning attempt2
* solving scipy deprecation warning attempt3
* A warning is triggered when a float->int conversion is detected
* bug fix
* docs
* release file updated
* Better handling of float->int conversion in EMD
* Corrected test for is_floating_point
* docs
* release file updated
* cupy does not allow implicit cast
* fromnumpy
* added test
* test da tf jax
* test unbalanced with no provided histogram
* using type_as argument in unif function correctly
* pep8
* changed the transport plan cast behaviour in emd: it now tries to cast to the histogram's dtype, defaulting to the cost matrix's dtype
Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
Diffstat (limited to 'ot/utils.py')
-rw-r--r-- | ot/utils.py | 26 |
1 file changed, 21 insertions, 5 deletions
diff --git a/ot/utils.py b/ot/utils.py index 725ca00..a23ce7e 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist import sys import warnings from inspect import signature -from .backend import get_backend +from .backend import get_backend, Backend __time_tic_toc = time.time() @@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs): def laplacian(x): r"""Compute Laplacian matrix""" - L = np.diag(np.sum(x, axis=0)) - x + nx = get_backend(x) + L = nx.diag(nx.sum(x, axis=0)) - x return L @@ -136,7 +137,7 @@ def unif(n, type_as=None): return np.ones((n,)) / n else: nx = get_backend(type_as) - return nx.ones((n,)) / n + return nx.ones((n,), type_as=type_as) / n def clean_zeros(a, b, M): @@ -296,7 +297,8 @@ def cost_normalization(C, norm=None): def dots(*args): r""" dots function for multiple matrix multiply """ - return reduce(np.dot, args) + nx = get_backend(*args) + return reduce(nx.dot, args) def label_normalization(y, start=0): @@ -314,8 +316,9 @@ def label_normalization(y, start=0): y : array-like, shape (`n1`, ) The input vector of labels normalized according to given start value. """ + nx = get_backend(y) - diff = np.min(np.unique(y)) - start + diff = nx.min(nx.unique(y)) - start if diff != 0: y -= diff return y @@ -482,6 +485,19 @@ class BaseEstimator(object): arguments (no ``*args`` or ``**kwargs``). """ + nx: Backend = None + + def _get_backend(self, *arrays): + nx = get_backend( + *[input_ for input_ in arrays if input_ is not None] + ) + if nx.__name__ in ("jax", "tf"): + raise TypeError( + """JAX or TF arrays have been received but domain + adaptation does not support those backend.""") + self.nx = nx + return nx + @classmethod def _get_param_names(cls): r"""Get parameter names for the estimator""" |