summaryrefslogtreecommitdiff
path: root/ot/utils.py
diff options
context:
space:
mode:
authorAntoine Rolet <antoine.rolet@gmail.com>2017-09-05 15:30:50 +0900
committerAntoine Rolet <antoine.rolet@gmail.com>2017-09-05 15:30:50 +0900
commit13dfb3ddbbd8926b4751b82dd41c5570253b1f07 (patch)
treeb28098e98640c64483a599103e2fdb5df46d2c79 /ot/utils.py
parent185eb3e2ef34b5ce6b8f90a28a5bcc78432b7fd3 (diff)
parent16697047eff9326a0ecb483317c13a854a3d3a71 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'ot/utils.py')
-rw-r--r--ot/utils.py346
1 files changed, 317 insertions, 29 deletions
diff --git a/ot/utils.py b/ot/utils.py
index 7ad7637..31a002b 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -2,36 +2,50 @@
"""
Various functions that can be useful
"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+#
+# License: MIT License
+
+import multiprocessing
+from functools import reduce
+import time
+
import numpy as np
from scipy.spatial.distance import cdist
-import multiprocessing
+import sys
+import warnings
+
+
+__time_tic_toc = time.time()
-import time
-__time_tic_toc=time.time()
def tic():
""" Python implementation of Matlab tic() function """
global __time_tic_toc
- __time_tic_toc=time.time()
+ __time_tic_toc = time.time()
+
def toc(message='Elapsed time : {} s'):
""" Python implementation of Matlab toc() function """
- t=time.time()
- print(message.format(t-__time_tic_toc))
- return t-__time_tic_toc
+ t = time.time()
+ print(message.format(t - __time_tic_toc))
+ return t - __time_tic_toc
+
def toq():
""" Python implementation of Julia toq() function """
- t=time.time()
- return t-__time_tic_toc
+ t = time.time()
+ return t - __time_tic_toc
-def kernel(x1,x2,method='gaussian',sigma=1,**kwargs):
+def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
"""Compute kernel matrix"""
- if method.lower() in ['gaussian','gauss','rbf']:
- K=np.exp(-dist(x1,x2)/(2*sigma**2))
+ if method.lower() in ['gaussian', 'gauss', 'rbf']:
+ K = np.exp(-dist(x1, x2) / (2 * sigma**2))
return K
+
def unif(n):
""" return a uniform histogram of length n (simplex)
@@ -48,17 +62,19 @@ def unif(n):
"""
- return np.ones((n,))/n
+ return np.ones((n,)) / n
+
-def clean_zeros(a,b,M):
- """ Remove all components with zeros weights in a and b
+def clean_zeros(a, b, M):
+ """ Remove all components with zero weights in a and b
"""
- M2=M[a>0,:][:,b>0].copy() # copy force c style matrix (froemd)
- a2=a[a>0]
- b2=b[b>0]
- return a2,b2,M2
+ M2 = M[a > 0, :][:, b > 0].copy() # copy force c style matrix (froemd)
+ a2 = a[a > 0]
+ b2 = b[b > 0]
+ return a2, b2, M2
-def dist(x1,x2=None,metric='sqeuclidean'):
+
+def dist(x1, x2=None, metric='sqeuclidean'):
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
Parameters
@@ -84,12 +100,12 @@ def dist(x1,x2=None,metric='sqeuclidean'):
"""
if x2 is None:
- x2=x1
+ x2 = x1
- return cdist(x1,x2,metric=metric)
+ return cdist(x1, x2, metric=metric)
-def dist0(n,method='lin_square'):
+def dist0(n, method='lin_square'):
"""Compute standard cost matrices of size (n,n) for OT problems
Parameters
@@ -111,16 +127,50 @@ def dist0(n,method='lin_square'):
"""
- res=0
- if method=='lin_square':
- x=np.arange(n,dtype=np.float64).reshape((n,1))
- res=dist(x,x)
+ res = 0
+ if method == 'lin_square':
+ x = np.arange(n, dtype=np.float64).reshape((n, 1))
+ res = dist(x, x)
return res
+def cost_normalization(C, norm=None):
+ """ Apply normalization to the loss matrix
+
+
+ Parameters
+ ----------
+ C : np.array (n1, n2)
+ The cost matrix to normalize.
+ norm : str
+ type of normalization from 'median', 'max', 'log', 'loglog'. Any other
+ value leaves the cost matrix unchanged.
+
+
+ Returns
+ -------
+
+ C : np.array (n1, n2)
+ The input cost matrix normalized according to given norm.
+
+ """
+
+ if norm == "median":
+ C /= float(np.median(C))
+ elif norm == "max":
+ C /= float(np.max(C))
+ elif norm == "log":
+ C = np.log(1 + C)
+ elif norm == "loglog":
+ C = np.log(1 + np.log(1 + C))
+
+ return C
+
+
def dots(*args):
""" dots function for multiple matrix multiply """
- return reduce(np.dot,args)
+ return reduce(np.dot, args)
+
def fun(f, q_in, q_out):
""" Utility function for parmap with no serializing problems """
@@ -130,6 +180,7 @@ def fun(f, q_in, q_out):
break
q_out.put((i, f(x)))
+
def parmap(f, X, nprocs=multiprocessing.cpu_count()):
""" parallel map for multiprocessing """
q_in = multiprocessing.Queue(1)
@@ -147,4 +198,241 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
[p.join() for p in proc]
- return [x for i, x in sorted(res)] \ No newline at end of file
+ return [x for i, x in sorted(res)]
+
+
+def check_params(**kwargs):
+ """check_params: check whether some parameters are missing
+ """
+
+ missing_params = []
+ check = True
+
+ for param in kwargs:
+ if kwargs[param] is None:
+ missing_params.append(param)
+
+ if len(missing_params) > 0:
+ print("POT - Warning: following necessary parameters are missing")
+ for p in missing_params:
+ print("\n", p)
+
+ check = False
+
+ return check
+
+
+class deprecated(object):
+ """Decorator to mark a function or class as deprecated.
+
+ deprecated class from scikit-learn package
+ https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
+ Issue a warning when the function is called/the class is instantiated and
+ adds a warning to the docstring.
+ The optional extra argument will be appended to the deprecation message
+ and the docstring. Note: to use this with the default value for extra, put
+ in an empty pair of parentheses:
+ >>> from ot.utils import deprecated
+ >>> @deprecated()
+ ... def some_function(): pass
+
+ Parameters
+ ----------
+ extra : string
+ to be added to the deprecation messages
+ """
+
+ # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
+ # but with many changes.
+
+ def __init__(self, extra=''):
+ self.extra = extra
+
+ def __call__(self, obj):
+ """Call method
+ Parameters
+ ----------
+ obj : object
+ """
+ if isinstance(obj, type):
+ return self._decorate_class(obj)
+ else:
+ return self._decorate_fun(obj)
+
+ def _decorate_class(self, cls):
+ msg = "Class %s is deprecated" % cls.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ # FIXME: we should probably reset __new__ for full generality
+ init = cls.__init__
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return init(*args, **kwargs)
+
+ cls.__init__ = wrapped
+
+ wrapped.__name__ = '__init__'
+ wrapped.__doc__ = self._update_doc(init.__doc__)
+ wrapped.deprecated_original = init
+
+ return cls
+
+ def _decorate_fun(self, fun):
+ """Decorate function fun"""
+
+ msg = "Function %s is deprecated" % fun.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return fun(*args, **kwargs)
+
+ wrapped.__name__ = fun.__name__
+ wrapped.__dict__ = fun.__dict__
+ wrapped.__doc__ = self._update_doc(fun.__doc__)
+
+ return wrapped
+
+ def _update_doc(self, olddoc):
+ newdoc = "DEPRECATED"
+ if self.extra:
+ newdoc = "%s: %s" % (newdoc, self.extra)
+ if olddoc:
+ newdoc = "%s\n\n%s" % (newdoc, olddoc)
+ return newdoc
+
+
+def _is_deprecated(func):
+ """Helper to check if func is wrapped by our deprecated decorator"""
+ if sys.version_info < (3, 5):
+ raise NotImplementedError("This is only available for python3.5 "
+ "or above")
+ closures = getattr(func, '__closure__', [])
+ if closures is None:
+ closures = []
+ is_deprecated = ('deprecated' in ''.join([c.cell_contents
+ for c in closures
+ if isinstance(c.cell_contents, str)]))
+ return is_deprecated
+
+
+class BaseEstimator(object):
+ """Base class for most objects in POT
+ adapted from sklearn BaseEstimator class
+
+ Notes
+ -----
+ All estimators should specify all the parameters that can be set
+ at the class level in their ``__init__`` as explicit keyword
+ arguments (no ``*args`` or ``**kwargs``).
+ """
+
+ @classmethod
+ def _get_param_names(cls):
+ """Get parameter names for the estimator"""
+ try:
+ from inspect import signature
+ except ImportError:
+ from .externals.funcsigs import signature
+ # fetch the constructor or the original constructor before
+ # deprecation wrapping if any
+ init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
+ if init is object.__init__:
+ # No explicit constructor to introspect
+ return []
+
+ # introspect the constructor arguments to find the model parameters
+ # to represent
+ init_signature = signature(init)
+ # Consider the constructor parameters excluding 'self'
+ parameters = [p for p in init_signature.parameters.values()
+ if p.name != 'self' and p.kind != p.VAR_KEYWORD]
+ for p in parameters:
+ if p.kind == p.VAR_POSITIONAL:
+ raise RuntimeError("POT estimators should always "
+ "specify their parameters in the signature"
+ " of their __init__ (no varargs)."
+ " %s with constructor %s doesn't "
+ " follow this convention."
+ % (cls, init_signature))
+ # Extract and sort argument names excluding 'self'
+ return sorted([p.name for p in parameters])
+
+ def get_params(self, deep=True):
+ """Get parameters for this estimator.
+
+ Parameters
+ ----------
+ deep : boolean, optional
+ If True, will return the parameters for this estimator and
+ contained subobjects that are estimators.
+
+ Returns
+ -------
+ params : mapping of string to any
+ Parameter names mapped to their values.
+ """
+ out = dict()
+ for key in self._get_param_names():
+ # We need deprecation warnings to always be on in order to
+ # catch deprecated param values.
+ # This is set in utils/__init__.py but it gets overwritten
+ # when running under python3 somehow.
+ warnings.simplefilter("always", DeprecationWarning)
+ try:
+ with warnings.catch_warnings(record=True) as w:
+ value = getattr(self, key, None)
+ if len(w) and w[0].category == DeprecationWarning:
+ # if the parameter is deprecated, don't show it
+ continue
+ finally:
+ warnings.filters.pop(0)
+
+ # XXX: should we rather test if instance of estimator?
+ if deep and hasattr(value, 'get_params'):
+ deep_items = value.get_params().items()
+ out.update((key + '__' + k, val) for k, val in deep_items)
+ out[key] = value
+ return out
+
+ def set_params(self, **params):
+ """Set the parameters of this estimator.
+
+ The method works on simple estimators as well as on nested objects
+ (such as pipelines). The latter have parameters of the form
+ ``<component>__<parameter>`` so that it's possible to update each
+ component of a nested object.
+
+ Returns
+ -------
+ self
+ """
+ if not params:
+ # Simple optimisation to gain speed (inspect is slow)
+ return self
+ valid_params = self.get_params(deep=True)
+ # for key, value in iteritems(params):
+ for key, value in params.items():
+ split = key.split('__', 1)
+ if len(split) > 1:
+ # nested objects case
+ name, sub_name = split
+ if name not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (name, self))
+ sub_object = valid_params[name]
+ sub_object.set_params(**{sub_name: value})
+ else:
+ # simple objects case
+ if key not in valid_params:
+ raise ValueError('Invalid parameter %s for estimator %s. '
+ 'Check the list of available parameters '
+ 'with `estimator.get_params().keys()`.' %
+ (key, self.__class__.__name__))
+ setattr(self, key, value)
+ return self