path: root/ot/dr.py
author     Rémi Flamary <remi.flamary@gmail.com>    2017-04-07 14:12:37 +0200
committer  Rémi Flamary <remi.flamary@gmail.com>    2017-04-07 14:12:37 +0200
commit     3cc99e6590fa87ae8705fc93315590b27bf84efc (patch)
tree       9e0665d1d24431238b373f442c3ed44fb0c9497d /ot/dr.py
parent     140baad30ab822deeccd6f1cb4fedc3136370ab4 (diff)
better documentation
Diffstat (limited to 'ot/dr.py')
-rw-r--r--    ot/dr.py    59
1 file changed, 53 insertions(+), 6 deletions(-)
diff --git a/ot/dr.py b/ot/dr.py
index 3732d81..3965149 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -1,16 +1,15 @@
# -*- coding: utf-8 -*-
"""
-Domain adaptation with optimal transport
+Dimension reduction with optimal transport
"""
-
import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1,x2):
- """ Compute squared euclidena distance between samples
+ """ Compute squared euclidean distance between samples
"""
x1p2=np.sum(np.square(x1),1)
x2p2=np.sum(np.square(x2),1)
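The hunk above shows only the squared-norm terms of dist(); the standard completion of this pairwise squared-distance expansion, with a quick check against a naive double loop, would look roughly like the sketch below (pairwise_sqdist is a hypothetical name for illustration, not necessarily the remaining lines of the file):

import numpy as np

def pairwise_sqdist(x1, x2):
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for every row pair
    x1p2 = np.sum(np.square(x1), 1)
    x2p2 = np.sum(np.square(x2), 1)
    return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)

# quick consistency check against a naive double loop
x1 = np.random.randn(4, 3)
x2 = np.random.randn(5, 3)
naive = np.array([[np.sum((a - b) ** 2) for b in x2] for a in x1])
assert np.allclose(pairwise_sqdist(x1, x2), naive)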
@@ -40,18 +39,66 @@ def split_classes(X,y):
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
"""
- Wasserstein Discriminant Analysis
+ Wasserstein Discriminant Analysis [11]_
The function solves the following optimization problem:
.. math::
- P = arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)}
+ P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
where :
-
+
+ - :math:`P` is a linear projection operator on the Stiefel(p,d) manifold
- :math:`W` is the entropic regularized Wasserstein distance
- :math:`X^i` are samples in the dataset corresponding to class i
+ Parameters
+ ----------
+ X : np.ndarray (n,d)
+ Training samples
+ y : np.ndarray (n,)
+ labels for training samples
+ p : int, optional
+ size of the dimensionality reduction
+ reg : float, optional
+ Regularization term >0 (entropic regularization)
+ k : int, optional
+ Number of inner Sinkhorn iterations
+ solver : None | str, optional
+ None for steepest descent or 'TrustRegions' for the trust regions algorithm
+ maxiter : int, optional
+ Max number of iterations
+ verbose : int, optional
+ Print information along iterations
+
+
+ Returns
+ -------
+ P : (d x p) ndarray
+ Optimal orthonormal projection for the given parameters
+ proj : callable
+ projection function including mean centering
+
+ Examples
+ --------
+
+ >>> import numpy as np
+ >>> import ot
+ >>> X = np.vstack((np.random.randn(20, 5), np.random.randn(20, 5) + 3))
+ >>> y = np.array([0] * 20 + [1] * 20)
+ >>> Pwda, projwda = ot.dr.wda(X, y, p=2, reg=1, k=10)  # doctest: +SKIP
+
+
+ References
+ ----------
+
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
+
+
+
+
"""
mx=np.mean(X)
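The new docstring states the objective, but the diff cuts off before the implementation. As a rough illustration only, the following sketch minimizes the documented ratio of within-class to between-class entropic Wasserstein costs over the Stiefel manifold, using the pre-2.0 pymanopt API imported at the top of the file. It assumes uniform sample weights; sinkhorn_cost and wda_sketch are hypothetical names and this is not the file's actual code.

import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent

def sinkhorn_cost(xi, xj, reg, k):
    # Entropic OT cost <G, M> between uniform empirical measures on the rows
    # of xi and xj, after k fixed-point Sinkhorn iterations (kept differentiable
    # so autograd can backpropagate through the loop).
    M = (np.sum(xi ** 2, 1).reshape((-1, 1))
         + np.sum(xj ** 2, 1).reshape((1, -1)) - 2 * np.dot(xi, xj.T))
    K = np.exp(-M / reg)
    a = np.ones(xi.shape[0]) / xi.shape[0]
    b = np.ones(xj.shape[0]) / xj.shape[0]
    u, v = a, b
    for _ in range(k):
        v = b / np.dot(K.T, u)
        u = a / np.dot(K, v)
    G = u.reshape((-1, 1)) * K * v.reshape((1, -1))
    return np.sum(G * M)

def wda_sketch(X, y, p=2, reg=1, k=10, maxiter=100):
    # Minimize sum_i W(PX^i, PX^i) / sum_{i, j != i} W(PX^i, PX^j)
    # over orthonormal projections P in Stiefel(d, p).
    classes = [X[y == c] for c in np.unique(y)]
    d = X.shape[1]

    def cost(P):
        proj = [np.dot(xc, P) for xc in classes]
        loss_w = sum(sinkhorn_cost(xi, xi, reg, k) for xi in proj)
        loss_b = sum(sinkhorn_cost(xi, xj, reg, k)
                     for i, xi in enumerate(proj)
                     for j, xj in enumerate(proj) if j != i)
        return loss_w / loss_b

    problem = Problem(manifold=Stiefel(d, p), cost=cost, verbosity=0)
    return SteepestDescent(maxiter=maxiter).solve(problem)

A call would look like P = wda_sketch(X, y, p=2), after which projected samples are np.dot(X, P); minimizing the ratio keeps same-class samples close and different-class samples far apart in the projected space, which is the intuition behind the formula in the docstring.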