summaryrefslogtreecommitdiff
path: root/ot/da.py
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2018-05-29 16:25:22 +0200
committerRémi Flamary <remi.flamary@gmail.com>2018-05-29 16:25:22 +0200
commit4641ec5f2ddbff1a468afaf65741aecae44738cc (patch)
tree001341734398607bf64c8101c0af002c9f62a5d7 /ot/da.py
parent90efa5a8b189214d1aeb81920b2bb04ce0c261ca (diff)
remove OTDA + pep8
Diffstat (limited to 'ot/da.py')
-rw-r--r--ot/da.py284
1 files changed, 1 insertions, 283 deletions
diff --git a/ot/da.py b/ot/da.py
index 48b418f..bc09e3c 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -15,7 +15,7 @@ import scipy.linalg as linalg
from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization
-from .utils import check_params, deprecated, BaseEstimator
+from .utils import check_params, BaseEstimator
from .optim import cg
from .optim import gcg
@@ -740,288 +740,6 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
return A, b
-@deprecated("The class OTDA is deprecated in 0.3.1 and will be "
- "removed in 0.5"
- "\n\tfor standard transport use class EMDTransport instead.")
-class OTDA(object):
-
- """Class for domain adaptation with optimal transport as proposed in [5]
-
-
- References
- ----------
-
- .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
- "Optimal Transport for Domain Adaptation," in IEEE Transactions on
- Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
-
- """
-
- def __init__(self, metric='sqeuclidean', norm=None):
- """ Class initialization"""
- self.xs = 0
- self.xt = 0
- self.G = 0
- self.metric = metric
- self.norm = norm
- self.computed = False
-
- def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
- """Fit domain adaptation between samples is xs and xt
- (with optional weights)"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = emd(ws, wt, self.M, max_iter)
- self.computed = True
-
- def interp(self, direction=1):
- """Barycentric interpolation for the source (1) or target (-1) samples
-
- This Barycentric interpolation solves for each source (resp target)
- sample xs (resp xt) the following optimization problem:
-
- .. math::
- arg\min_x \sum_i \gamma_{k,i} c(x,x_i^t)
-
- where k is the index of the sample in xs
-
- For the moment only squared euclidean distance is provided but more
- metric could be used in the future.
-
- """
- if direction > 0: # >0 then source to target
- G = self.G
- w = self.ws.reshape((self.xs.shape[0], 1))
- x = self.xt
- else:
- G = self.G.T
- w = self.wt.reshape((self.xt.shape[0], 1))
- x = self.xs
-
- if self.computed:
- if self.metric == 'sqeuclidean':
- return np.dot(G / w, x) # weighted mean
- else:
- print(
- "Warning, metric not handled yet, using weighted average")
- return np.dot(G / w, x) # weighted mean
- return None
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
- def predict(self, x, direction=1):
- """ Out of sample mapping using the formulation from [6]
-
- For each sample x to map, it finds the nearest source sample xs and
- map the samle x to the position xst+(x-xs) wher xst is the barycentric
- interpolation of source sample xs.
-
- References
- ----------
-
- .. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
- Regularized discrete optimal transport. SIAM Journal on Imaging
- Sciences, 7(3), 1853-1882.
-
- """
- if direction > 0: # >0 then source to target
- xf = self.xt
- x0 = self.xs
- else:
- xf = self.xs
- x0 = self.xt
-
- D0 = dist(x, x0) # dist netween new samples an source
- idx = np.argmin(D0, 1) # closest one
- xf = self.interp(direction) # interp the source samples
- # aply the delta to the interpolation
- return xf[idx, :] + x - x0[idx, :]
-
-
-@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornTransport instead.")
-class OTDA_sinkhorn(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic
- regularization
-
-
- """
-
- def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights)"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_lpl1 is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornLpl1Transport instead.")
-class OTDA_lpl1(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic and
- group regularization"""
-
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
- parameters"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_l1L2 is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class SinkhornL1l2Transport instead.")
-class OTDA_l1l2(OTDA):
-
- """Class for domain adaptation with optimal transport with entropic
- and group lasso regularization"""
-
- def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
- """Fit regularized domain adaptation between samples is xs and xt
- (with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
- parameters"""
- self.xs = xs
- self.xt = xt
-
- if wt is None:
- wt = unif(xt.shape[0])
- if ws is None:
- ws = unif(xs.shape[0])
-
- self.ws = ws
- self.wt = wt
-
- self.M = dist(xs, xt, metric=self.metric)
- self.M = cost_normalization(self.M, self.norm)
- self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
- self.computed = True
-
-
-@deprecated("The class OTDA_mapping_linear is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class MappingTransport instead.")
-class OTDA_mapping_linear(OTDA):
-
- """Class for optimal transport with joint linear mapping estimation as in
- [8]
- """
-
- def __init__(self):
- """ Class initialization"""
-
- self.xs = 0
- self.xt = 0
- self.G = 0
- self.L = 0
- self.bias = False
- self.computed = False
- self.metric = 'sqeuclidean'
-
- def fit(self, xs, xt, mu=1, eta=1, bias=False, **kwargs):
- """ Fit domain adaptation between samples is xs and xt (with optional
- weights)"""
- self.xs = xs
- self.xt = xt
- self.bias = bias
-
- self.ws = unif(xs.shape[0])
- self.wt = unif(xt.shape[0])
-
- self.G, self.L = joint_OT_mapping_linear(
- xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
- self.computed = True
-
- def mapping(self):
- return lambda x: self.predict(x)
-
- def predict(self, x):
- """ Out of sample mapping estimated during the call to fit"""
- if self.computed:
- if self.bias:
- x = np.hstack((x, np.ones((x.shape[0], 1))))
- return x.dot(self.L) # aply the delta to the interpolation
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
-
-@deprecated("The class OTDA_mapping_kernel is deprecated in 0.3.1 and will be"
- " removed in 0.5 \nUse class MappingTransport instead.")
-class OTDA_mapping_kernel(OTDA_mapping_linear):
-
- """Class for optimal transport with joint nonlinear mapping
- estimation as in [8]"""
-
- def fit(self, xs, xt, mu=1, eta=1, bias=False, kerneltype='gaussian',
- sigma=1, **kwargs):
- """ Fit domain adaptation between samples is xs and xt """
- self.xs = xs
- self.xt = xt
- self.bias = bias
-
- self.ws = unif(xs.shape[0])
- self.wt = unif(xt.shape[0])
- self.kernel = kerneltype
- self.sigma = sigma
- self.kwargs = kwargs
-
- self.G, self.L = joint_OT_mapping_kernel(
- xs, xt, mu=mu, eta=eta, bias=bias, **kwargs)
- self.computed = True
-
- def predict(self, x):
- """ Out of sample mapping estimated during the call to fit"""
-
- if self.computed:
- K = kernel(
- x, self.xs, method=self.kernel, sigma=self.sigma,
- **self.kwargs)
- if self.bias:
- K = np.hstack((K, np.ones((x.shape[0], 1))))
- return K.dot(self.L)
- else:
- print("Warning, model not fitted yet, returning None")
- return None
-
-
def distribution_estimation_uniform(X):
"""estimates a uniform distribution from an array of samples X