summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ot/da.py22
-rw-r--r--ot/deprecation.py103
-rw-r--r--test/test_da.py5
3 files changed, 126 insertions, 4 deletions
diff --git a/ot/da.py b/ot/da.py
index 3ccb1b3..8fa1895 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -10,12 +10,14 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
+import warnings
+
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel
from .optim import cg
from .optim import gcg
-import warnings
+from .deprecation import deprecated
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -632,6 +634,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
return G, L
+@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
+ "removed in 0.5"
+ "\n\tfor standard transport use class EMDTransport instead.")
class OTDA(object):
"""Class for domain adaptation with optimal transport as proposed in [5]
@@ -758,10 +763,15 @@ class OTDA(object):
self.M = np.log(1 + np.log(1 + self.M))
+@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class SinkhornTransport instead.")
class OTDA_sinkhorn(OTDA):
"""Class for domain adaptation with optimal transport with entropic
- regularization"""
+ regularization
+
+
+ """
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
"""Fit regularized domain adaptation between samples is xs and xt
@@ -783,6 +793,8 @@ class OTDA_sinkhorn(OTDA):
self.computed = True
+@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
class OTDA_lpl1(OTDA):
"""Class for domain adaptation with optimal transport with entropic and
@@ -810,6 +822,8 @@ class OTDA_lpl1(OTDA):
self.computed = True
+@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
class OTDA_l1l2(OTDA):
"""Class for domain adaptation with optimal transport with entropic
@@ -837,6 +851,8 @@ class OTDA_l1l2(OTDA):
self.computed = True
+@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class MappingTransport instead.")
class OTDA_mapping_linear(OTDA):
"""Class for optimal transport with joint linear mapping estimation as in
@@ -882,6 +898,8 @@ class OTDA_mapping_linear(OTDA):
return None
+@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
+ " removed in 0.5 \nUse class MappingTransport instead.")
class OTDA_mapping_kernel(OTDA_mapping_linear):
"""Class for optimal transport with joint nonlinear mapping
diff --git a/ot/deprecation.py b/ot/deprecation.py
new file mode 100644
index 0000000..2b16427
--- /dev/null
+++ b/ot/deprecation.py
@@ -0,0 +1,103 @@
+"""
+ deprecated class from scikit-learn package
+ https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/deprecation.py
+"""
+
+import sys
+import warnings
+
+__all__ = ["deprecated", ]
+
+
+class deprecated(object):
+ """Decorator to mark a function or class as deprecated.
+ Issue a warning when the function is called/the class is instantiated and
+ adds a warning to the docstring.
+ The optional extra argument will be appended to the deprecation message
+ and the docstring. Note: to use this with the default value for extra, put
+ in an empty of parentheses:
+ >>> from ot.deprecation import deprecated
+ >>> @deprecated()
+ ... def some_function(): pass
+
+ Parameters
+ ----------
+ extra : string
+ to be added to the deprecation messages
+ """
+
+ # Adapted from http://wiki.python.org/moin/PythonDecoratorLibrary,
+ # but with many changes.
+
+ def __init__(self, extra=''):
+ self.extra = extra
+
+ def __call__(self, obj):
+ """Call method
+ Parameters
+ ----------
+ obj : object
+ """
+ if isinstance(obj, type):
+ return self._decorate_class(obj)
+ else:
+ return self._decorate_fun(obj)
+
+ def _decorate_class(self, cls):
+ msg = "Class %s is deprecated" % cls.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ # FIXME: we should probably reset __new__ for full generality
+ init = cls.__init__
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return init(*args, **kwargs)
+
+ cls.__init__ = wrapped
+
+ wrapped.__name__ = '__init__'
+ wrapped.__doc__ = self._update_doc(init.__doc__)
+ wrapped.deprecated_original = init
+
+ return cls
+
+ def _decorate_fun(self, fun):
+ """Decorate function fun"""
+
+ msg = "Function %s is deprecated" % fun.__name__
+ if self.extra:
+ msg += "; %s" % self.extra
+
+ def wrapped(*args, **kwargs):
+ warnings.warn(msg, category=DeprecationWarning)
+ return fun(*args, **kwargs)
+
+ wrapped.__name__ = fun.__name__
+ wrapped.__dict__ = fun.__dict__
+ wrapped.__doc__ = self._update_doc(fun.__doc__)
+
+ return wrapped
+
+ def _update_doc(self, olddoc):
+ newdoc = "DEPRECATED"
+ if self.extra:
+ newdoc = "%s: %s" % (newdoc, self.extra)
+ if olddoc:
+ newdoc = "%s\n\n%s" % (newdoc, olddoc)
+ return newdoc
+
+
+def _is_deprecated(func):
+ """Helper to check if func is wraped by our deprecated decorator"""
+ if sys.version_info < (3, 5):
+ raise NotImplementedError("This is only available for python3.5 "
+ "or above")
+ closures = getattr(func, '__closure__', [])
+ if closures is None:
+ closures = []
+ is_deprecated = ('deprecated' in ''.join([c.cell_contents
+ for c in closures
+ if isinstance(c.cell_contents, str)]))
+ return is_deprecated
diff --git a/test/test_da.py b/test/test_da.py
index 162f681..9578b3d 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -432,10 +432,11 @@ def test_otda():
da_emd.predict(xs) # interpolation of source samples
-if __name__ == "__main__":
+# if __name__ == "__main__":
+ # test_otda()
# test_sinkhorn_transport_class()
# test_emd_transport_class()
# test_sinkhorn_l1l2_transport_class()
# test_sinkhorn_lpl1_transport_class()
- test_mapping_transport_class()
+ # test_mapping_transport_class()