From 3cc99e6590fa87ae8705fc93315590b27bf84efc Mon Sep 17 00:00:00 2001 From: Rémi Flamary Date: Fri, 7 Apr 2017 14:12:37 +0200 Subject: better documentation --- ot/dr.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 6 deletions(-) (limited to 'ot/dr.py') diff --git a/ot/dr.py b/ot/dr.py index 3732d81..3965149 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ -Domain adaptation with optimal transport +Dimension reduction with optimal transport """ - import autograd.numpy as np from pymanopt.manifolds import Stiefel from pymanopt import Problem from pymanopt.solvers import SteepestDescent, TrustRegions def dist(x1,x2): - """ Compute squared euclidena distance between samples + """ Compute squared euclidean distance between samples """ x1p2=np.sum(np.square(x1),1) x2p2=np.sum(np.square(x2),1) @@ -40,18 +39,66 @@ def split_classes(X,y): def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0): """ - Wasserstein Discriminant Analysis + Wasserstein Discriminant Analysis [11]_ The function solves the following optimization problem: .. 
math:: - P = arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)} + P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)} where : - + + - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold - :math:`W` is entropic regularized Wasserstein distances - :math:`X^i` are samples in the dataset corresponding to class i + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) + samples weights in the target domain + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns x nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary returned only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.sinkhorn(a,b,M,1) + array([[ 0.36552929, 0.13447071], + [ 0.13447071, 0.36552929]]) + + + References + ---------- + + .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. + + + + """ mx=np.mean(X) -- cgit v1.2.3