summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py498
1 files changed, 498 insertions, 0 deletions
diff --git a/ot/utils.py b/ot/utils.py
new file mode 100644
index 0000000..b71458b
--- /dev/null
+++ b/ot/utils.py
@@ -0,0 +1,498 @@
+# -*- coding: utf-8 -*-
+"""
+Various useful functions
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import multiprocessing
+from functools import reduce
+import time
+
+import numpy as np
+from scipy.spatial.distance import cdist
+import sys
+import warnings
+try:
+ from inspect import signature
+except ImportError:
+ from .externals.funcsigs import signature
+
# Reference timestamp shared by tic(), toc() and toq(). It is initialised at
# import time so that toc()/toq() return a value even before the first tic().
__time_tic_toc = time.time()


def tic():
    """Start (or restart) the module timer, like Matlab's ``tic``.

    Stores the current wall-clock time (``time.time()``) in the module-level
    reference timestamp.
    """
    global __time_tic_toc
    now = time.time()
    __time_tic_toc = now
+
+
def toc(message='Elapsed time : {} s'):
    """Print and return the time elapsed since the timer was last started.

    Python counterpart of Matlab's ``toc``.

    Parameters
    ----------
    message : str, optional
        Format string with one ``{}`` placeholder that receives the elapsed
        time in seconds.

    Returns
    -------
    elapsed : float
        Seconds between the stored reference timestamp and now.
    """
    now = time.time()
    print(message.format(now - __time_tic_toc))
    elapsed = now - __time_tic_toc
    return elapsed
+
+
def toq():
    """Return the time elapsed since the timer was last started, silently.

    Python counterpart of Julia's ``toq``: same value as ``toc`` but nothing
    is printed.

    Returns
    -------
    elapsed : float
        Seconds between the stored reference timestamp and now.
    """
    elapsed = time.time() - __time_tic_toc
    return elapsed
+
+
def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
    """Compute the kernel matrix between the samples of x1 and x2.

    Parameters
    ----------
    x1 : ndarray, shape (n1, d)
        First set of samples.
    x2 : ndarray, shape (n2, d)
        Second set of samples.
    method : str, optional
        Kernel name, case-insensitive. Only the Gaussian kernel is
        implemented, under the aliases 'gaussian', 'gauss' and 'rbf'.
    sigma : float, optional
        Bandwidth of the Gaussian kernel.
    **kwargs
        Accepted for interface compatibility; currently unused.

    Returns
    -------
    K : ndarray, shape (n1, n2)
        ``exp(-D / (2 * sigma**2))`` where ``D`` is the squared Euclidean
        distance matrix returned by ``dist(x1, x2)``.

    Raises
    ------
    ValueError
        If ``method`` is not a supported kernel name.
    """
    if method.lower() not in ('gaussian', 'gauss', 'rbf'):
        # Fail explicitly instead of falling through with no result.
        raise ValueError("Unknown kernel method %r. Valid options are: "
                         "'gaussian', 'gauss', 'rbf'" % (method,))
    return np.exp(-dist(x1, x2) / (2 * sigma**2))
+
+
def unif(n):
    """Return a uniform histogram of length n (a point of the simplex).

    Parameters
    ----------
    n : int
        Number of bins in the histogram.

    Returns
    -------
    h : np.array (n,)
        Histogram of length n such that h_i = 1/n for all i.
    """
    weights = np.ones(n)
    return weights / n
+
+
def clean_zeros(a, b, M):
    """Drop every component of a and b whose weight is not strictly positive.

    Parameters
    ----------
    a : ndarray, shape (n1,)
        Weights indexing the rows of M.
    b : ndarray, shape (n2,)
        Weights indexing the columns of M.
    M : ndarray, shape (n1, n2)
        Matrix to restrict to the kept rows and columns.

    Returns
    -------
    a2 : ndarray
        Entries of a that are > 0.
    b2 : ndarray
        Entries of b that are > 0.
    M2 : ndarray
        Copy of M restricted to the kept rows and columns.
    """
    keep_rows = a > 0
    keep_cols = b > 0
    # The explicit copy forces a C-ordered matrix (needed by the EMD solver,
    # per the original author's note).
    M2 = M[keep_rows, :][:, keep_cols].copy()
    return a[keep_rows], b[keep_cols], M2
+
+
def euclidean_distances(X, Y, squared=False):
    """Compute the pairwise Euclidean distances between rows of X and Y.

    Uses the expansion ``||x - y||^2 = ||x||^2 - 2 <x, y> + ||y||^2``.

    Parameters
    ----------
    X : {array-like}, shape (n_samples_1, n_features)
    Y : {array-like}, shape (n_samples_2, n_features)
    squared : boolean, optional
        Return squared Euclidean distances.

    Returns
    -------
    distances : {array}, shape (n_samples_1, n_samples_2)
        Floating-point distance matrix. Floating-point inputs keep their
        dtype; integer or boolean inputs are converted to float64.
    """
    # Record identity before conversion: np.asarray / astype may return new
    # objects, which would otherwise defeat the ``X is Y`` check below.
    same_input = X is Y

    X = np.asarray(X)
    if not np.issubdtype(X.dtype, np.inexact):
        # Integer inputs would break the in-place sqrt further down.
        X = X.astype(np.float64)
    if same_input:
        Y = X
    else:
        Y = np.asarray(Y)
        if not np.issubdtype(Y.dtype, np.inexact):
            Y = Y.astype(np.float64)

    XX = np.einsum('ij,ij->i', X, X)[:, np.newaxis]
    YY = np.einsum('ij,ij->i', Y, Y)[np.newaxis, :]
    distances = np.dot(X, Y.T)
    distances *= -2
    distances += XX
    distances += YY
    # Rounding errors can produce tiny negative values; clip them to zero.
    np.maximum(distances, 0, out=distances)
    if same_input:
        # Ensure that distances between vectors and themselves are set to 0.0.
        # This may not be the case due to floating point rounding errors.
        distances.flat[::distances.shape[0] + 1] = 0.0
    return distances if squared else np.sqrt(distances, out=distances)
+
+
def dist(x1, x2=None, metric='sqeuclidean'):
    """Compute the distance matrix between the samples of x1 and x2.

    The squared Euclidean metric is handled by ``euclidean_distances``; every
    other metric is delegated to ``scipy.spatial.distance.cdist``.

    Parameters
    ----------

    x1 : ndarray, shape (n1,d)
        matrix with n1 samples of size d
    x2 : array, shape (n2,d), optional
        matrix with n2 samples of size d (if None then x2=x1)
    metric : str | callable, optional
        Name of the metric to be computed, or a callable, as accepted by
        ``scipy.spatial.distance.cdist`` (e.g. 'euclidean', 'cityblock',
        'chebyshev', 'cosine', 'minkowski', 'sqeuclidean', ...). See the
        scipy documentation for the full list.


    Returns
    -------

    M : np.array (n1,n2)
        distance matrix computed with given metric

    """
    other = x1 if x2 is None else x2
    if metric != "sqeuclidean":
        return cdist(x1, other, metric=metric)
    return euclidean_distances(x1, other, squared=True)
+
+
def dist0(n, method='lin_square'):
    """Compute standard cost matrices of size (n, n) for OT problems

    Parameters
    ----------
    n : int
        Size of the cost matrix.
    method : str, optional
        Type of loss matrix chosen from:

        * 'lin_square' : linear sampling between 0 and n-1, quadratic loss

    Returns
    -------
    M : ndarray, shape (n, n)
        Distance matrix computed with the given method. For any other value
        of ``method`` the integer 0 is returned.
    """
    if method != 'lin_square':
        # Unknown methods fall back to the scalar 0.
        return 0
    grid = np.arange(n, dtype=np.float64).reshape((n, 1))
    return dist(grid, grid)
+
+
def cost_normalization(C, norm=None):
    """ Apply normalization to the loss matrix

    The input is never modified: every normalization returns a new array,
    so integer cost matrices are supported as well.

    Parameters
    ----------
    C : ndarray, shape (n1, n2)
        The cost matrix to normalize.
    norm : str
        Type of normalization from 'median', 'max', 'log', 'loglog'.
        ``None`` (the default) returns ``C`` unchanged.

    Returns
    -------
    C : ndarray, shape (n1, n2)
        The input cost matrix normalized according to given norm.

    Raises
    ------
    ValueError
        If ``norm`` is neither ``None`` nor one of the valid options.
    """

    if norm is None:
        return C
    if norm == "median":
        # Out-of-place division: leaves the caller's array untouched and
        # works for integer dtypes.
        return np.asarray(C) / float(np.median(C))
    if norm == "max":
        return np.asarray(C) / float(np.max(C))
    if norm == "log":
        return np.log1p(C)
    if norm == "loglog":
        return np.log1p(np.log1p(C))
    raise ValueError('Norm %s is not a valid option.\n'
                     'Valid options are:\n'
                     'median, max, log, loglog' % norm)
+
+
def dots(*args):
    """Chain ``np.dot`` over several matrices, left to right.

    ``dots(A, B, C)`` evaluates ``np.dot(np.dot(A, B), C)``; a single
    argument is returned as is.
    """
    return reduce(lambda left, right: np.dot(left, right), args)
+
+
def fun(f, q_in, q_out):
    """Worker loop for parmap, kept at module level to avoid serializing problems.

    Repeatedly takes an ``(index, value)`` pair from ``q_in`` and puts
    ``(index, f(value))`` on ``q_out``. The loop ends when a pair whose index
    is ``None`` is received.
    """
    index, value = q_in.get()
    while index is not None:
        q_out.put((index, f(value)))
        index, value = q_in.get()
+
+
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
    """Parallel map based on multiprocessing (plain serial map where
    process forking is not usable).

    Parameters
    ----------
    f : callable
        Function applied to every element of X.
    X : iterable
        Elements to process.
    nprocs : int, optional
        Number of worker processes. Defaults to the CPU count evaluated at
        import time.

    Returns
    -------
    results : list
        ``[f(x) for x in X]``, in the order of X.

    Raises
    ------
    ValueError
        If ``nprocs`` is smaller than 1.
    """
    if nprocs < 1:
        # With no worker the bounded input queue would block forever.
        raise ValueError('nprocs must be at least 1, got %r' % (nprocs,))

    # The workers receive ``f`` through process forking (no pickling), which
    # is unavailable on Windows and not the default on macOS.
    if sys.platform.endswith(('win32', 'darwin')):
        return list(map(f, X))

    if hasattr(multiprocessing, 'get_context'):
        try:
            # Request fork explicitly so the result does not depend on the
            # interpreter's default start method.
            mp = multiprocessing.get_context('fork')
        except ValueError:
            # fork is not available on this platform
            return list(map(f, X))
    else:
        mp = multiprocessing

    q_in = mp.Queue(1)
    q_out = mp.Queue()

    workers = [mp.Process(target=fun, args=(f, q_in, q_out))
               for _ in range(nprocs)]
    for p in workers:
        p.daemon = True
        p.start()

    n_sent = 0
    for item in enumerate(X):
        q_in.put(item)
        n_sent += 1
    # One stop marker per worker.
    for _ in range(nprocs):
        q_in.put((None, None))

    # NOTE(review): if ``f`` raises inside a worker, its result never arrives
    # and this loop appears to block forever -- confirm whether callers
    # need error propagation.
    res = [q_out.get() for _ in range(n_sent)]

    for p in workers:
        p.join()

    # Results arrive in completion order; restore the input order.
    return [x for i, x in sorted(res, key=lambda pair: pair[0])]
+
+
def check_params(**kwargs):
    """Check whether some of the given parameters are missing.

    A parameter counts as missing when its value is ``None``. The names of
    the missing parameters are printed as a warning.

    Parameters
    ----------
    **kwargs
        Parameters to check, passed as ``name=value``.

    Returns
    -------
    check : bool
        True if no parameter is None, False otherwise.
    """
    missing_params = [name for name in kwargs if kwargs[name] is None]

    if not missing_params:
        return True

    print("POT - Warning: following necessary parameters are missing")
    for name in missing_params:
        print("\n", name)
    return False
+
+
def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance

    Parameters
    ----------
    seed : None | int | instance of RandomState
        If seed is None (or the ``np.random`` module), return the RandomState
        singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.
    """
    if isinstance(seed, np.random.RandomState):
        return seed
    if isinstance(seed, (int, np.integer)):
        return np.random.RandomState(seed)
    if seed is not None and seed is not np.random:
        raise ValueError('{} cannot be used to seed a numpy.random.RandomState'
                         ' instance'.format(seed))
    return np.random.mtrand._rand
+
+
class deprecated(object):
    """Decorator to mark a function or class as deprecated.

    deprecated class from scikit-learn package
    https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
    Issue a warning when the function is called/the class is instantiated and
    adds a warning to the docstring.
    The optional extra argument will be appended to the deprecation message
    and the docstring. Note: to use this with the default value for extra, put
    in an empty pair of parentheses:
    >>> from ot.utils import deprecated  # doctest: +SKIP
    >>> @deprecated()  # doctest: +SKIP
    ... def some_function(): pass  # doctest: +SKIP

    Parameters
    ----------
    extra : str
        To be added to the deprecation messages.
    """

    # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
    # but with many changes.

    def __init__(self, extra=''):
        self.extra = extra

    def __call__(self, obj):
        """Decorate ``obj`` and return the deprecated version of it.

        Parameters
        ----------
        obj : object
            Class or function to mark as deprecated.
        """
        if isinstance(obj, type):
            return self._decorate_class(obj)
        return self._decorate_fun(obj)

    def _decorate_class(self, cls):
        """Replace ``cls.__init__`` by a wrapper that warns, return ``cls``."""
        msg = "Class %s is deprecated" % cls.__name__
        if self.extra:
            msg += "; %s" % self.extra

        # FIXME: we should probably reset __new__ for full generality
        init = cls.__init__

        def wrapped(*args, **kwargs):
            # stacklevel=2 attributes the warning to the instantiating code.
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return init(*args, **kwargs)

        cls.__init__ = wrapped

        wrapped.__name__ = '__init__'
        wrapped.__doc__ = self._update_doc(init.__doc__)
        # Lets introspection code reach the real constructor.
        wrapped.deprecated_original = init

        return cls

    def _decorate_fun(self, fun):
        """Return a wrapper around ``fun`` that warns on every call."""
        from functools import wraps

        msg = "Function %s is deprecated" % fun.__name__
        if self.extra:
            msg += "; %s" % self.extra

        # wraps copies the function metadata and updates (rather than shares)
        # the attribute dictionary of the wrapped function.
        @wraps(fun)
        def wrapped(*args, **kwargs):
            # stacklevel=2 attributes the warning to the calling code.
            warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
            return fun(*args, **kwargs)

        wrapped.__doc__ = self._update_doc(fun.__doc__)

        return wrapped

    def _update_doc(self, olddoc):
        """Prefix ``olddoc`` with a DEPRECATED notice (and the extra text)."""
        newdoc = "DEPRECATED"
        if self.extra:
            newdoc = "%s: %s" % (newdoc, self.extra)
        if olddoc:
            newdoc = "%s\n\n%s" % (newdoc, olddoc)
        return newdoc
+
+
def _is_deprecated(func):
    """Helper to check if func is wrapped by our deprecated decorator.

    The check concatenates every string captured in the closure of ``func``
    and looks for the substring 'deprecated' in the result.

    Raises
    ------
    NotImplementedError
        On Python versions older than 3.5.
    """
    if sys.version_info < (3, 5):
        raise NotImplementedError("This is only available for python3.5 "
                                  "or above")
    cells = getattr(func, '__closure__', None) or []
    captured_text = ''.join([cell.cell_contents
                             for cell in cells
                             if isinstance(cell.cell_contents, str)])
    return 'deprecated' in captured_text
+
+
class BaseEstimator(object):
    """Base class for most objects in POT

    Code adapted from sklearn BaseEstimator class

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator

        Returns
        -------
        names : list of str
            Sorted names of the ``__init__`` parameters, excluding ``self``
            and any ``**kwargs``. Empty if the class defines no constructor.

        Raises
        ------
        RuntimeError
            If the constructor declares ``*args``.
        """

        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = signature(init)
        # Consider the constructor parameters excluding 'self'
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("POT estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (cls, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values. Parameters whose access
            emits a DeprecationWarning are left out.
        """
        out = dict()
        for key in self._get_param_names():
            # Deprecation warnings must always be recorded in order to catch
            # deprecated param values. The filter is installed inside
            # catch_warnings so that the global warning filters are restored
            # untouched when the block exits.
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always", DeprecationWarning)
                value = getattr(self, key, None)
            if any(issubclass(w.category, DeprecationWarning)
                   for w in caught):
                # if the parameter is deprecated, don't show it
                continue

            # Classes are skipped: they expose get_params as an attribute but
            # it cannot be called without an instance.
            if (deep and hasattr(value, 'get_params')
                    and not isinstance(value, type)):
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a parameter name (or the component part of a nested name) is
            not a valid parameter of this estimator.
        """
        if not params:
            # Simple optimisation to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)
        for key, value in params.items():
            split = key.split('__', 1)
            if len(split) > 1:
                # nested objects case
                name, sub_name = split
                if name not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (name, self))
                sub_object = valid_params[name]
                sub_object.set_params(**{sub_name: value})
            else:
                # simple objects case
                if key not in valid_params:
                    raise ValueError('Invalid parameter %s for estimator %s. '
                                     'Check the list of available parameters '
                                     'with `estimator.get_params().keys()`.' %
                                     (key, self.__class__.__name__))
                setattr(self, key, value)
        return self
+
+
class UndefinedParameter(Exception):
    """Exception to raise when an undefined parameter is requested."""