summaryrefslogtreecommitdiff
path: root/ot/gpu/da.py
diff options
context:
space:
mode:
Diffstat (limited to 'ot/gpu/da.py')
-rw-r--r--ot/gpu/da.py144
1 files changed, 144 insertions, 0 deletions
diff --git a/ot/gpu/da.py b/ot/gpu/da.py
new file mode 100644
index 0000000..4a98038
--- /dev/null
+++ b/ot/gpu/da.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+"""
+Domain adaptation with optimal transport with GPU implementation
+"""
+
+# Author: Remi Flamary <remi.flamary@unice.fr>
+# Nicolas Courty <ncourty@irisa.fr>
+# Michael Perrot <michael.perrot@univ-st-etienne.fr>
+# Leo Gautheron <https://github.com/aje>
+#
+# License: MIT License
+
+
+import cupy as np # np used for matrix computation
+import cupy as cp # cp used for cupy specific operations
+import numpy as npp
+from . import utils
+
+from .bregman import sinkhorn
+
+
+def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
+ numInnerItermax=200, stopInnerThr=1e-9, verbose=False,
+ log=False, to_numpy=True):
+ """
+ Solve the entropic regularization optimal transport problem with nonconvex
+ group lasso regularization on GPU
+
+ If the input matrices are in numpy format, they will be uploaded to the
+ GPU first which can incur significant time overhead.
+
+
+ The function solves the following optimization problem:
+
+ .. math::
+ \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega_e(\gamma)
+ + \eta \Omega_g(\gamma)
+
+ s.t. \gamma 1 = a
+
+ \gamma^T 1= b
+
+ \gamma\geq 0
+ where :
+
+ - M is the (ns,nt) metric cost matrix
+ - :math:`\Omega_e` is the entropic regularization term
+ :math:`\Omega_e(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
+ - :math:`\Omega_g` is the group lasso regularization term
+ :math:`\Omega_g(\gamma)=\sum_{i,c} \|\gamma_{i,\mathcal{I}_c}\|^{1/2}_1`
+ where :math:`\mathcal{I}_c` are the indices of samples from class c
+ in the source domain.
+ - a and b are source and target weights (sum to 1)
+
+ The algorithm used for solving the problem is the generalised conditional
+ gradient as proposed in [5]_ [7]_
+
+
+ Parameters
+ ----------
+ a : np.ndarray (ns,)
+ samples weights in the source domain
+ labels_a : np.ndarray (ns,)
+ labels of samples in the source domain
+ b : np.ndarray (nt,)
+ samples weights in the target domain
+ M : np.ndarray (ns,nt)
+ loss matrix
+ reg : float
+ Regularization term for entropic regularization >0
+ eta : float, optional
+ Regularization term for group lasso regularization >0
+ numItermax : int, optional
+ Max number of outer iterations (each one runs a full sinkhorn solve)
+ numInnerItermax : int, optional
+ Max number of iterations (inner sinkhorn solver)
+ stopInnerThr : float, optional
+ Stop threshold on error (inner sinkhorn solver) (>0)
+ verbose : bool, optional
+ Print information along iterations (currently unused by this function)
+ log : bool, optional
+ record log if True (currently unused: no log is built or returned)
+ to_numpy : boolean, optional (default True)
+ If true convert back the GPU array result to numpy format.
+
+
+ Returns
+ -------
+ gamma : (ns x nt) ndarray
+ Optimal transportation matrix for the given parameters
+ log : dict
+ NOTE(review): never returned by this code; only gamma is returned
+
+
+ References
+ ----------
+
+ .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
+ "Optimal Transport for Domain Adaptation," in IEEE
+ Transactions on Pattern Analysis and Machine Intelligence ,
+ vol.PP, no.99, pp.1-1
+ .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
+ Generalized conditional gradient: analysis of convergence
+ and applications. arXiv preprint arXiv:1510.06567.
+
+ See Also
+ --------
+ ot.lp.emd : Unregularized OT
+ ot.bregman.sinkhorn : Entropic regularized OT
+ ot.optim.cg : General regularized OT
+
+ """
+
+ a, labels_a, b, M = utils.to_gpu(a, labels_a, b, M)
+
+ p = 0.5
+ epsilon = 1e-3
+
+ indices_labels = []
+ labels_a2 = cp.asnumpy(labels_a)
+ classes = npp.unique(labels_a2)
+ for c in classes:
+ idxc, = utils.to_gpu(npp.where(labels_a2 == c))
+ indices_labels.append(idxc)
+
+ W = np.zeros(M.shape)
+
+ for cpt in range(numItermax):
+ Mreg = M + eta * W
+ transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
+ stopThr=stopInnerThr, to_numpy=False)
+ # the transport has been computed. Rebuild the weights W from it: for each
+ # class, rows of that class get p * (summed class rows + epsilon)**(p - 1)
+ W = np.ones(M.shape)
+ for (i, c) in enumerate(classes):
+
+ majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = p * ((majs + epsilon)**(p - 1))
+ W[indices_labels[i]] = majs
+
+ if to_numpy:
+ return utils.to_np(transp)
+ else:
+ return transp