summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRémi Flamary <remi.flamary@gmail.com>2017-04-07 14:12:37 +0200
committerRémi Flamary <remi.flamary@gmail.com>2017-04-07 14:12:37 +0200
commit3cc99e6590fa87ae8705fc93315590b27bf84efc (patch)
tree9e0665d1d24431238b373f442c3ed44fb0c9497d
parent140baad30ab822deeccd6f1cb4fedc3136370ab4 (diff)
better documentation
-rw-r--r--README.md2
-rw-r--r--docs/source/conf.py2
-rw-r--r--ot/dr.py59
3 files changed, 56 insertions, 7 deletions
diff --git a/README.md b/README.md
index 4bd03f9..15cd599 100644
--- a/README.md
+++ b/README.md
@@ -105,3 +105,5 @@ This toolbox benefit a lot from open source research and we would like to thank
[9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
+
+[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. \ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
index e0d4e0b..529d410 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -24,7 +24,7 @@ class Mock(MagicMock):
@classmethod
def __getattr__(cls, name):
return Mock()
-MOCK_MODULES = [ 'emd','ot.lp.emd']
+MOCK_MODULES = [ 'emd','ot.lp.emd_wrap']
sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
# !!!!
diff --git a/ot/dr.py b/ot/dr.py
index 3732d81..3965149 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -1,16 +1,15 @@
# -*- coding: utf-8 -*-
"""
-Domain adaptation with optimal transport
+Dimension reduction with optimal transport
"""
-
import autograd.numpy as np
from pymanopt.manifolds import Stiefel
from pymanopt import Problem
from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1,x2):
- """ Compute squared euclidena distance between samples
+ """ Compute squared euclidean distance between samples
"""
x1p2=np.sum(np.square(x1),1)
x2p2=np.sum(np.square(x2),1)
@@ -40,18 +39,66 @@ def split_classes(X,y):
def wda(X,y,p=2,reg=1,k=10,solver = None,maxiter=100,verbose=0):
"""
- Wasserstein Discriminant Analysis
+ Wasserstein Discriminant Analysis [11]_
The function solves the following optimization problem:
.. math::
- P = arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)}
+ P = \\text{arg}\min_P \\frac{\\sum_i W(PX^i,PX^i)}{\\sum_{i,j\\neq i} W(PX^i,PX^j)}
where :
-
+
+ - :math:`P` is a linear projection operator in the Stiefel(p,d) manifold
- :math:`W` is entropic regularized Wasserstein distances
- :math:`X^i` are samples in the dataset corresponding to class i
+    Parameters
+    ----------
+    X : np.ndarray (n,d)
+        Training samples
+    y : np.ndarray (n,)
+        Labels for training samples
+    p : int, optional
+        Size of the projected subspace
+    reg : float, optional
+        Regularization term >0 (entropic regularization)
+    k : int, optional
+        Number of inner Sinkhorn iterations (confirm against body)
+    solver : None | str, optional
+        None for SteepestDescent or 'TrustRegions' (pymanopt solvers)
+    maxiter : int, optional
+        Max number of iterations
+    verbose : int, optional
+        Print information along iterations
+
+
+    Returns
+    -------
+    P : ndarray
+        Optimal linear projection on the Stiefel(p,d) manifold
+        (exact return values not visible in this hunk -- confirm
+        against the full function body)
+
+ Examples
+ --------
+
+    >>> import ot
+    >>> import numpy as np
+    >>> X = np.random.randn(20, 5)
+    >>> y = np.concatenate((np.zeros(10), np.ones(10)))
+    >>> Pw = ot.dr.wda(X, y, p=2, reg=1)  # doctest: +SKIP
+
+
+
+
+ References
+ ----------
+
+ .. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063.
+
+
+
+
"""
mx=np.mean(X)