summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/plot_otda_jcpot.py171
-rw-r--r--ot/bregman.py160
-rw-r--r--ot/da.py181
-rw-r--r--test/test_da.py56
4 files changed, 3 insertions, 565 deletions
diff --git a/examples/plot_otda_jcpot.py b/examples/plot_otda_jcpot.py
deleted file mode 100644
index 316fa8b..0000000
--- a/examples/plot_otda_jcpot.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-========================
-OT for multi-source target shift
-========================
-
-This example introduces a target shift problem with two 2D source and 1 target domain.
-
-"""
-
-# Authors: Remi Flamary <remi.flamary@unice.fr>
-# Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
-#
-# License: MIT License
-
-import pylab as pl
-import numpy as np
-import ot
-from ot.datasets import make_data_classif
-
-##############################################################################
-# Generate data
-# -------------
-n = 50
-sigma = 0.3
-np.random.seed(1985)
-
-p1 = .2
-dec1 = [0, 2]
-
-p2 = .9
-dec2 = [0, -2]
-
-pt = .4
-dect = [4, 0]
-
-xs1, ys1 = make_data_classif('2gauss_prop', n, nz=sigma, p=p1, bias=dec1)
-xs2, ys2 = make_data_classif('2gauss_prop', n + 1, nz=sigma, p=p2, bias=dec2)
-xt, yt = make_data_classif('2gauss_prop', n, nz=sigma, p=pt, bias=dect)
-
-all_Xr = [xs1, xs2]
-all_Yr = [ys1, ys2]
-# %%
-
-da = 1.5
-
-
-def plot_ax(dec, name):
- pl.plot([dec[0], dec[0]], [dec[1] - da, dec[1] + da], 'k', alpha=0.5)
- pl.plot([dec[0] - da, dec[0] + da], [dec[1], dec[1]], 'k', alpha=0.5)
- pl.text(dec[0] - .5, dec[1] + 2, name)
-
-
-##############################################################################
-# Fig 1 : plots source and target samples
-# ---------------------------------------
-
-pl.figure(1)
-pl.clf()
-plot_ax(dec1, 'Source 1')
-plot_ax(dec2, 'Source 2')
-plot_ax(dect, 'Target')
-pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9,
- label='Source 1 ({:1.2f}, {:1.2f})'.format(1 - p1, p1))
-pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9,
- label='Source 2 ({:1.2f}, {:1.2f})'.format(1 - p2, p2))
-pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9,
- label='Target ({:1.2f}, {:1.2f})'.format(1 - pt, pt))
-pl.title('Data')
-
-pl.legend()
-pl.axis('equal')
-pl.axis('off')
-
-##############################################################################
-# Instantiate Sinkhorn transport algorithm and fit them for all source domains
-# ----------------------------------------------------------------------------
-ot_sinkhorn = ot.da.SinkhornTransport(reg_e=1e-1, metric='sqeuclidean')
-
-
-def print_G(G, xs, ys, xt):
- for i in range(G.shape[0]):
- for j in range(G.shape[1]):
- if G[i, j] > 5e-4:
- if ys[i]:
- c = 'b'
- else:
- c = 'r'
- pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], c, alpha=.2)
-
-
-##############################################################################
-# Fig 2 : plot optimal couplings and transported samples
-# ------------------------------------------------------
-pl.figure(2)
-pl.clf()
-plot_ax(dec1, 'Source 1')
-plot_ax(dec2, 'Source 2')
-plot_ax(dect, 'Target')
-print_G(ot_sinkhorn.fit(Xs=xs1, Xt=xt).coupling_, xs1, ys1, xt)
-print_G(ot_sinkhorn.fit(Xs=xs2, Xt=xt).coupling_, xs2, ys2, xt)
-pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
-pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
-pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
-
-pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
-pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
-
-pl.title('Independent OT')
-
-pl.legend()
-pl.axis('equal')
-pl.axis('off')
-
-##############################################################################
-# Instantiate JCPOT adaptation algorithm and fit it
-# ----------------------------------------------------------------------------
-otda = ot.da.JCPOTTransport(reg_e=1e-2, max_iter=1000, metric='sqeuclidean', tol=1e-9, verbose=True, log=True)
-otda.fit(all_Xr, all_Yr, xt)
-
-ws1 = otda.proportions_.dot(otda.log_['D2'][0])
-ws2 = otda.proportions_.dot(otda.log_['D2'][1])
-
-pl.figure(3)
-pl.clf()
-plot_ax(dec1, 'Source 1')
-plot_ax(dec2, 'Source 2')
-plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
-pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
-pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
-pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
-
-pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
-pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
-
-pl.title('OT with prop estimation ({:1.3f},{:1.3f})'.format(otda.proportions_[0], otda.proportions_[1]))
-
-pl.legend()
-pl.axis('equal')
-pl.axis('off')
-
-##############################################################################
-# Run oracle transport algorithm with known proportions
-# ----------------------------------------------------------------------------
-h_res = np.array([1 - pt, pt])
-
-ws1 = h_res.dot(otda.log_['D2'][0])
-ws2 = h_res.dot(otda.log_['D2'][1])
-
-pl.figure(4)
-pl.clf()
-plot_ax(dec1, 'Source 1')
-plot_ax(dec2, 'Source 2')
-plot_ax(dect, 'Target')
-print_G(ot.bregman.sinkhorn(ws1, [], otda.log_['M'][0], reg=1e-2), xs1, ys1, xt)
-print_G(ot.bregman.sinkhorn(ws2, [], otda.log_['M'][1], reg=1e-2), xs2, ys2, xt)
-pl.scatter(xs1[:, 0], xs1[:, 1], c=ys1, s=35, marker='x', cmap='Set1', vmax=9)
-pl.scatter(xs2[:, 0], xs2[:, 1], c=ys2, s=35, marker='+', cmap='Set1', vmax=9)
-pl.scatter(xt[:, 0], xt[:, 1], c=yt, s=35, marker='o', cmap='Set1', vmax=9)
-
-pl.plot([], [], 'r', alpha=.2, label='Mass from Class 1')
-pl.plot([], [], 'b', alpha=.2, label='Mass from Class 2')
-
-pl.title('OT with known proportion ({:1.1f},{:1.1f})'.format(h_res[0], h_res[1]))
-
-pl.legend()
-pl.axis('equal')
-pl.axis('off')
-pl.show()
diff --git a/ot/bregman.py b/ot/bregman.py
index 61dfa52..f737e81 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -1503,166 +1503,6 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
return np.sum(K0, axis=1)
-def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
- stopThr=1e-6, verbose=False, log=False, **kwargs):
- r'''Joint OT and proportion estimation for multi-source target shift as proposed in [27]
-
- The function solves the following optimization problem:
-
- .. math::
-
- \mathbf{h} = arg\min_{\mathbf{h}}\quad \sum_{k=1}^{K} \lambda_k
- W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})
-
- s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}
-
- where :
-
- - :math:`\lambda_k` is the weight of k-th source domain
- - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
- - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to k-th source domain defined as in [p. 5, 27], its expected shape is `(n_k, C)` where `n_k` is the number of elements in the k-th source domain and `C` is the number of classes
- - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size C
- - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
- - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in [p. 5, 27], its expected shape is `(n_k, C)`
-
- The problem consist in solving a Wasserstein barycenter problem to estimate the proportions :math:`\mathbf{h}` in the target domain.
-
- The algorithm used for solving the problem is the Iterative Bregman projections algorithm
- with two sets of marginal constraints related to the unknown vector :math:`\mathbf{h}` and uniform tarhet distribution.
-
- Parameters
- ----------
- Xs : list of K np.ndarray(nsk,d)
- features of all source domains' samples
- Ys : list of K np.ndarray(nsk,)
- labels of all source domains' samples
- Xt : np.ndarray (nt,d)
- samples in the target domain
- reg : float
- Regularization term > 0
- metric : string, optional (default="sqeuclidean")
- The ground metric for the Wasserstein problem
- numItermax : int, optional
- Max number of iterations
- stopThr : float, optional
- Stop threshold on relative change in the barycenter (>0)
- log : bool, optional
- record log if True
- verbose : bool, optional (default=False)
- Controls the verbosity of the optimization algorithm
-
- Returns
- -------
- gamma : List of K (nsk x nt) ndarrays
- Optimal transportation matrices for the given parameters for each pair of source and target domains
- h : (C,) ndarray
- proportion estimation in the target domain
- log : dict
- log dictionary return only if log==True in parameters
-
-
- References
- ----------
-
- .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
-
- '''
- nbclasses = len(np.unique(Ys[0]))
- nbdomains = len(Xs)
-
- # log dictionary
- if log:
- log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': []}
-
- K = []
- M = []
- D1 = []
- D2 = []
-
- # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
- for d in range(nbdomains):
- dom = {}
- nsk = Xs[d].shape[0] # get number of elements for this domain
- dom['nbelem'] = nsk
- classes = np.unique(Ys[d]) # get number of classes for this domain
-
- # format classes to start from 0 for convenience
- if np.min(classes) != 0:
- Ys[d] = Ys[d] - np.min(classes)
- classes = np.unique(Ys[d])
-
- # build the corresponding D_1 and D_2 matrices
- Dtmp1 = np.zeros((nbclasses, nsk))
- Dtmp2 = np.zeros((nbclasses, nsk))
-
- for c in classes:
- nbelemperclass = np.sum(Ys[d] == c)
- if nbelemperclass != 0:
- Dtmp1[int(c), Ys[d] == c] = 1.
- Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- D1.append(Dtmp1)
- D2.append(Dtmp2)
-
- # build the cost matrix and the Gibbs kernel
- Mtmp = dist(Xs[d], Xt, metric=metric)
- Mtmp = Mtmp / np.median(Mtmp)
- M.append(Mtmp)
-
- Ktmp = np.empty(Mtmp.shape, dtype=Mtmp.dtype)
- np.divide(Mtmp, -reg, out=Ktmp)
- np.exp(Ktmp, out=Ktmp)
- K.append(Ktmp)
-
- # uniform target distribution
- a = unif(np.shape(Xt)[0])
-
- cpt = 0 # iterations count
- err = 1
- old_bary = np.ones((nbclasses))
-
- while (err > stopThr and cpt < numItermax):
-
- bary = np.zeros((nbclasses))
-
- # update coupling matrices for marginal constraints w.r.t. uniform target distribution
- for d in range(nbdomains):
- K[d] = projC(K[d], a)
- other = np.sum(K[d], axis=1)
- bary = bary + np.log(np.dot(D1[d], other)) / nbdomains
-
- bary = np.exp(bary)
-
- # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
- for d in range(nbdomains):
- new = np.dot(D2[d].T, bary)
- K[d] = projR(K[d], new)
-
- err = np.linalg.norm(bary - old_bary)
- cpt = cpt + 1
- old_bary = bary
-
- if log:
- log['err'].append(err)
-
- if verbose:
- if cpt % 200 == 0:
- print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
- print('{:5d}|{:8e}|'.format(cpt, err))
-
- bary = bary / np.sum(bary)
-
- if log:
- log['niter'] = cpt
- log['M'] = M
- log['D1'] = D1
- log['D2'] = D2
- return K, bary, log
- else:
- return K, bary
-
-
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
numIterMax=10000, stopThr=1e-9, verbose=False,
log=False, **kwargs):
diff --git a/ot/da.py b/ot/da.py
index 0fdd3be..474c944 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -14,7 +14,7 @@ Domain adaptation with optimal transport
import numpy as np
import scipy.linalg as linalg
-from .bregman import sinkhorn, jcpot_barycenter
+from .bregman import sinkhorn
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, laplacian
from .utils import check_params, BaseEstimator
@@ -2121,181 +2121,4 @@ class UnbalancedSinkhornTransport(BaseTransport):
self.coupling_ = returned_
self.log_ = dict()
- return self
-
-
-class JCPOTTransport(BaseTransport):
-
- """Domain Adapatation OT method for multi-source target shift based on Wasserstein barycenter algorithm.
-
- Parameters
- ----------
- reg_e : float, optional (default=1)
- Entropic regularization parameter
- max_iter : int, float, optional (default=10)
- The minimum number of iteration before stopping the optimization
- algorithm if no it has not converged
- tol : float, optional (default=10e-9)
- Stop threshold on error (inner sinkhorn solver) (>0)
- verbose : bool, optional (default=False)
- Controls the verbosity of the optimization algorithm
- log : bool, optional (default=False)
- Controls the logs of the optimization algorithm
- metric : string, optional (default="sqeuclidean")
- The ground metric for the Wasserstein problem
- norm : string, optional (default=None)
- If given, normalize the ground metric to avoid numerical errors that
- can occur with large metric values.
- distribution_estimation : callable, optional (defaults to the uniform)
- The kind of distribution estimation to employ
- out_of_sample_map : string, optional (default="ferradans")
- The kind of out of sample mapping to apply to transport samples
- from a domain into another one. Currently the only possible option is
- "ferradans" which uses the method proposed in [6].
-
- Attributes
- ----------
- coupling_ : list of array-like objects, shape K x (n_source_samples, n_target_samples)
- A set of optimal couplings between each source domain and the target domain
- proportions_ : array-like, shape (n_classes,)
- Estimated class proportions in the target domain
- log_ : dictionary
- The dictionary of log, empty dic if parameter log is not True
-
- References
- ----------
-
- .. [1] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
- "Optimal transport for multi-source domain adaptation under target shift",
- International Conference on Artificial Intelligence and Statistics (AISTATS),
- vol. 89, p.849-858, 2019.
-
- """
-
- def __init__(self, reg_e=.1, max_iter=10,
- tol=10e-9, verbose=False, log=False,
- metric="sqeuclidean",
- out_of_sample_map='ferradans'):
- self.reg_e = reg_e
- self.max_iter = max_iter
- self.tol = tol
- self.verbose = verbose
- self.log = log
- self.metric = metric
- self.out_of_sample_map = out_of_sample_map
-
- def fit(self, Xs, ys=None, Xt=None, yt=None):
- """Building coupling matrices from a list of source and target sets of samples
- (Xs, ys) and (Xt, yt)
-
- Parameters
- ----------
- Xs : list of K array-like objects, shape K x (nk_source_samples, n_features)
- A list of the training input samples.
- ys : list of K array-like objects, shape K x (nk_source_samples,)
- A list of the class labels
- Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
- yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
-
- Warning: Note that, due to this convention -1 cannot be used as a
- class label
-
- Returns
- -------
- self : object
- Returns self.
- """
-
- # check the necessary inputs parameters are here
- if check_params(Xs=Xs, Xt=Xt, ys=ys):
-
- self.xs_ = Xs
- self.xt_ = Xt
-
- returned_ = jcpot_barycenter(Xs=Xs, Ys=ys, Xt=Xt, reg=self.reg_e,
- metric=self.metric, distrinumItermax=self.max_iter, stopThr=self.tol,
- verbose=self.verbose, log=self.log)
-
- # deal with the value of log
- if self.log:
- self.coupling_, self.proportions_, self.log_ = returned_
- else:
- self.coupling_, self.proportions_ = returned_
- self.log_ = dict()
-
- return self
-
- def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
- """Transports source samples Xs onto target ones Xt
-
- Parameters
- ----------
- Xs : array-like, shape (n_source_samples, n_features)
- The training input samples.
- ys : array-like, shape (n_source_samples,)
- The class labels
- Xt : array-like, shape (n_target_samples, n_features)
- The training input samples.
- yt : array-like, shape (n_target_samples,)
- The class labels. If some target samples are unlabeled, fill the
- yt's elements with -1.
-
- Warning: Note that, due to this convention -1 cannot be used as a
- class label
- batch_size : int, optional (default=128)
- The batch size for out of sample inverse transform
- """
-
- transp_Xs = []
-
- # check the necessary inputs parameters are here
- if check_params(Xs=Xs):
-
- if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
-
- # perform standard barycentric mapping for each source domain
-
- for coupling in self.coupling_:
- transp = coupling / np.sum(coupling, 1)[:, None]
-
- # set nans to 0
- transp[~ np.isfinite(transp)] = 0
-
- # compute transported samples
- transp_Xs.append(np.dot(transp, self.xt_))
- else:
-
- # perform out of sample mapping
- indices = np.arange(Xs.shape[0])
- batch_ind = [
- indices[i:i + batch_size]
- for i in range(0, len(indices), batch_size)]
-
- transp_Xs = []
-
- for bi in batch_ind:
- transp_Xs_ = []
-
- # get the nearest neighbor in the sources domains
- xs = np.concatenate(self.xs_, axis=0)
- idx = np.argmin(dist(Xs[bi], xs), axis=1)
-
- # transport the source samples
- for coupling in self.coupling_:
- transp = coupling / np.sum(
- coupling, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_.append(np.dot(transp, self.xt_))
-
- transp_Xs_ = np.concatenate(transp_Xs_, axis=0)
-
- # define the transported points
- transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :]
- transp_Xs.append(transp_Xs_)
-
- transp_Xs = np.concatenate(transp_Xs, axis=0)
-
- return transp_Xs
+ return self \ No newline at end of file
diff --git a/test/test_da.py b/test/test_da.py
index 4eaf193..0e31f26 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -549,60 +549,6 @@ def test_linear_mapping_class():
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_jcpot_transport_class():
- """test_jcpot_transport
- """
-
- ns1 = 150
- ns2 = 150
- nt = 200
-
- Xs1, ys1 = make_data_classif('3gauss', ns1)
- Xs2, ys2 = make_data_classif('3gauss', ns2)
-
- Xt, yt = make_data_classif('3gauss2', nt)
-
- Xs = [Xs1, Xs2]
- ys = [ys1, ys2]
-
- otda = ot.da.JCPOTTransport(reg_e=0.01, max_iter=1000, tol=1e-9, verbose=True, log=True)
-
- # test its computed
- otda.fit(Xs=Xs, ys=ys, Xt=Xt)
-
- assert hasattr(otda, "coupling_")
- assert hasattr(otda, "proportions_")
- assert hasattr(otda, "log_")
-
- # test dimensions of coupling
- for i, xs in enumerate(Xs):
- assert_equal(otda.coupling_[i].shape, ((xs.shape[0], Xt.shape[0])))
-
- # test all margin constraints
- mu_t = unif(nt)
-
- for i in range(len(Xs)):
- # test margin constraints w.r.t. uniform target weights for each coupling matrix
- assert_allclose(
- np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3)
-
- # test margin constraints w.r.t. modified source weights for each source domain
-
- assert_allclose(
- np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
- atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- [assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
-
- Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
-
def test_emd_laplace_class():
"""test_emd_laplace_transport
"""
@@ -654,4 +600,4 @@ def test_emd_laplace_class():
# test fit_transform
transp_Xs = otda.fit_transform(Xs=Xs, Xt=Xt)
- assert_equal(transp_Xs.shape, Xs.shape)
+ assert_equal(transp_Xs.shape, Xs.shape) \ No newline at end of file