From fc58f39fc730a9e1bb2215ef063e37c50f0ebc1f Mon Sep 17 00:00:00 2001 From: Slasnista Date: Wed, 23 Aug 2017 15:09:08 +0200 Subject: added deprecation warning on old classes --- ot/da.py | 22 ++++++++++-- ot/deprecation.py | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ test/test_da.py | 5 +-- 3 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 ot/deprecation.py diff --git a/ot/da.py b/ot/da.py index 3ccb1b3..8fa1895 100644 --- a/ot/da.py +++ b/ot/da.py @@ -10,12 +10,14 @@ Domain adaptation with optimal transport # License: MIT License import numpy as np +import warnings + from .bregman import sinkhorn from .lp import emd from .utils import unif, dist, kernel from .optim import cg from .optim import gcg -import warnings +from .deprecation import deprecated def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, @@ -632,6 +634,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', return G, L +@deprecated("The class OTDA is deprecated in 0.3.1 and will be " + "removed in 0.5" + "\n\tfor standard transport use class EMDTransport instead.") class OTDA(object): """Class for domain adaptation with optimal transport as proposed in [5] @@ -758,10 +763,15 @@ class OTDA(object): self.M = np.log(1 + np.log(1 + self.M)) +@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be" + " removed in 0.5 \nUse class SinkhornTransport instead.") class OTDA_sinkhorn(OTDA): """Class for domain adaptation with optimal transport with entropic - regularization""" + regularization + + + """ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs): """Fit regularized domain adaptation between samples is xs and xt @@ -783,6 +793,8 @@ class OTDA_sinkhorn(OTDA): self.computed = True +@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be" + " removed in 0.5 \nUse class SinkhornLpl1Transport instead.") class OTDA_lpl1(OTDA): """Class for domain adaptation with optimal transport with entropic and @@ -810,6 +822,8 @@ class OTDA_lpl1(OTDA): self.computed = True +@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be" + " removed in 0.5 \nUse class SinkhornL1l2Transport instead.") class OTDA_l1l2(OTDA): """Class for domain adaptation with optimal transport with entropic @@ -837,6 +851,8 @@ class OTDA_l1l2(OTDA): self.computed = True +@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be" + " removed in 0.5 \nUse class MappingTransport instead.") class OTDA_mapping_linear(OTDA): """Class for optimal transport with joint linear mapping estimation as in @@ -882,6 +898,8 @@ class OTDA_mapping_linear(OTDA): return None +@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be" + " removed in 0.5 \nUse class MappingTransport instead.") class OTDA_mapping_kernel(OTDA_mapping_linear): """Class for optimal transport with joint nonlinear mapping diff --git a/ot/deprecation.py b/ot/deprecation.py new file mode 100644 index 0000000..2b16427 --- /dev/null +++ b/ot/deprecation.py @@ -0,0 +1,103 @@ +""" + deprecated class from scikit-learn package + https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py +""" + +import sys +import warnings + +__all__ = ["deprecated", ] + + +class deprecated(object): + """Decorator to mark a function or class as deprecated. + Issue a warning when the function is called/the class is instantiated and + adds a warning to the docstring. + The optional extra argument will be appended to the deprecation message + and the docstring. Note: to use this with the default value for extra, put + in an empty of parentheses: + >>> from ot.deprecation import deprecated + >>> @deprecated() + ... def some_function(): pass + + Parameters + ---------- + extra : string + to be added to the deprecation messages + """ + + # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary, + # but with many changes. + + def __init__(self, extra=''): + self.extra = extra + + def __call__(self, obj): + """Call method + Parameters + ---------- + obj : object + """ + if isinstance(obj, type): + return self._decorate_class(obj) + else: + return self._decorate_fun(obj) + + def _decorate_class(self, cls): + msg = "Class %s is deprecated" % cls.__name__ + if self.extra: + msg += "; %s" % self.extra + + # FIXME: we should probably reset __new__ for full generality + init = cls.__init__ + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return init(*args, **kwargs) + + cls.__init__ = wrapped + + wrapped.__name__ = '__init__' + wrapped.__doc__ = self._update_doc(init.__doc__) + wrapped.deprecated_original = init + + return cls + + def _decorate_fun(self, fun): + """Decorate function fun""" + + msg = "Function %s is deprecated" % fun.__name__ + if self.extra: + msg += "; %s" % self.extra + + def wrapped(*args, **kwargs): + warnings.warn(msg, category=DeprecationWarning) + return fun(*args, **kwargs) + + wrapped.__name__ = fun.__name__ + wrapped.__dict__ = fun.__dict__ + wrapped.__doc__ = self._update_doc(fun.__doc__) + + return wrapped + + def _update_doc(self, olddoc): + newdoc = "DEPRECATED" + if self.extra: + newdoc = "%s: %s" % (newdoc, self.extra) + if olddoc: + newdoc = "%s\n\n%s" % (newdoc, olddoc) + return newdoc + + +def _is_deprecated(func): + """Helper to check if func is wraped by our deprecated decorator""" + if sys.version_info < (3, 5): + raise NotImplementedError("This is only available for python3.5 " + "or above") + closures = getattr(func, '__closure__', []) + if closures is None: + closures = [] + is_deprecated = ('deprecated' in ''.join([c.cell_contents + for c in closures + if isinstance(c.cell_contents, str)])) + return is_deprecated diff --git a/test/test_da.py b/test/test_da.py index 162f681..9578b3d 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -432,10 +432,11 @@ def test_otda(): da_emd.predict(xs) # interpolation of source samples -if __name__ == "__main__": +# if __name__ == "__main__": + # test_otda() # test_sinkhorn_transport_class() # test_emd_transport_class() # test_sinkhorn_l1l2_transport_class() # test_sinkhorn_lpl1_transport_class() - test_mapping_transport_class() + # test_mapping_transport_class() -- cgit v1.2.3