diff options
authorRémi Flamary <>2017-04-07 14:00:10 +0200
committerRémi Flamary <>2017-04-07 14:00:10 +0200
commit140baad30ab822deeccd6f1cb4fedc3136370ab4 (patch)
parentb30a380fe5a5115cd0a5596f08903259c077f12c (diff)
add WDA
2 files changed, 164 insertions, 0 deletions
diff --git a/examples/ b/examples/
new file mode 100644
index 0000000..8d24bdc
--- /dev/null
+++ b/examples/
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+"""
+Wasserstein Discriminant Analysis (WDA) example
+
+@author: rflamary
+"""
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+from ot.datasets import get_1D_gauss as gauss
+from ot.dr import wda
+#%% parameters
+n=1000 # nb samples in source and target datasets
+#%% plot samples
+pl.scatter(xt[:,0],xt[:,1],c=ys,marker='+',label='Source samples')
+pl.title('Discriminant dimensions')
+#%% compute the WDA projection
+P,proj = wda(xs,ys,p,reg,k,maxiter=maxiter)
+#%% plot samples
+pl.scatter(xsp[:,0],xsp[:,1],c=ys,marker='+',label='Projected samples')
+pl.title('Projected training samples')
+pl.scatter(xtp[:,0],xtp[:,1],c=ys,marker='+',label='Projected samples')
+pl.title('Projected test samples')
diff --git a/ot/ b/ot/
new file mode 100644
index 0000000..3732d81
--- /dev/null
+++ b/ot/
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+"""
+Dimension reduction with optimal transport
+"""
+import autograd.numpy as np
+from pymanopt.manifolds import Stiefel
+from pymanopt import Problem
+from pymanopt.solvers import SteepestDescent, TrustRegions
def dist(x1, x2):
    """Compute the squared Euclidean distance between two sets of samples.

    Parameters
    ----------
    x1 : array, shape (n1, d)
        First set of samples, one sample per row.
    x2 : array, shape (n2, d)
        Second set of samples, one sample per row.

    Returns
    -------
    M : array, shape (n1, n2)
        ``M[i, j]`` is the squared Euclidean distance between ``x1[i]`` and
        ``x2[j]``.
    """
    # Squared norms of every row of each set.
    x1p2 = np.sum(np.square(x1), 1)
    x2p2 = np.sum(np.square(x2), 1)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>, broadcast over all pairs.
    cross = np.dot(x1, x2.T)
    return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * cross
def sinkhorn(w1, w2, M, reg, k):
    """Simple Sinkhorn solver running a fixed number of iterations.

    There is no convergence test: exactly ``k`` scaling updates are applied.

    Parameters
    ----------
    w1 : array, shape (n1,)
        Weights of the first distribution (target row sums of the plan).
    w2 : array, shape (n2,)
        Weights of the second distribution (target column sums of the plan).
    M : array, shape (n1, n2)
        Ground cost matrix.
    reg : float
        Entropic regularization strength; the kernel is ``exp(-M / reg)``.
    k : int
        Number of Sinkhorn iterations. With ``k == 0`` the kernel itself is
        returned.

    Returns
    -------
    G : array, shape (n1, n2)
        Transport plan ``diag(ui) K diag(vi)``. Because ``ui`` is the last
        scaling updated, the rows of ``G`` sum to ``w1`` whenever ``k >= 1``;
        the columns only approach ``w2`` as ``k`` grows.
    """
    n1, n2 = M.shape[0], M.shape[1]
    K = np.exp(-M / reg)
    ui = np.ones((n1,))
    vi = np.ones((n2,))
    for _ in range(k):
        # Alternate the two marginal-matching scaling updates.
        vi = w2 / np.dot(K.T, ui)
        ui = w1 / np.dot(K, vi)
    G = ui.reshape((n1, 1)) * K * vi.reshape((1, n2))
    return G
def split_classes(X, y):
    """Split the samples in X according to the class labels in y.

    Parameters
    ----------
    X : array-like, shape (n, d)
        Samples, one per row.
    y : array-like, shape (n,)
        Class label of every sample. Lists are accepted as well as arrays.

    Returns
    -------
    list of arrays
        One float32 array per distinct label, ordered by sorted label value
        (the order returned by ``np.unique``); each holds the rows of ``X``
        carrying that label.
    """
    # Coerce to arrays so the elementwise comparison below yields a boolean
    # mask even when the caller passes plain Python lists.
    X = np.asarray(X)
    y = np.asarray(y)
    labels = np.unique(y)
    return [X[y == label, :].astype(np.float32) for label in labels]
def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0):
    r"""
    Wasserstein Discriminant Analysis

    The function solves the following optimization problem:

    .. math::
        P = arg\min_P \frac{\sum_i W(PX^i,PX^i)}{\sum_{i,j\neq i} W(PX^i,PX^j)}

    where :

    - :math:`P` is a linear projection operator in the Stiefel(d, p) manifold
    - :math:`W` is entropic regularized Wasserstein distances
    - :math:`X^i` are samples in the dataset corresponding to class i

    The implementation counts every pair of distinct classes once in the
    denominator, which rescales the ratio by a constant and leaves the
    minimizer unchanged.

    Parameters
    ----------
    X : array, shape (n, d)
        Training samples, one per row. The array is not modified.
    y : array, shape (n,)
        Class label of every sample.
    p : int, optional
        Number of columns of the projection (dimension of the subspace).
    reg : float, optional
        Entropic regularization passed to the Sinkhorn solver.
    k : int, optional
        Number of Sinkhorn iterations.
    solver : None, str or solver object, optional
        ``None`` selects ``SteepestDescent``; ``'tr'`` or ``'TrustRegions'``
        selects ``TrustRegions``; any other non-string object is used as is
        and must expose ``solve(problem)``.
    maxiter : int, optional
        Maximum number of iterations given to the built-in solvers.
    verbose : int, optional
        Forwarded to the built-in solvers as ``logverbosity``.

    Returns
    -------
    Popt : array, shape (d, p)
        Projection returned by the solver.
    proj : callable
        ``proj(Xnew)`` centers ``Xnew`` with the training mean and projects it
        with ``Popt``.

    Raises
    ------
    ValueError
        If ``y`` holds fewer than two classes (the between-class term would be
        zero) or if ``solver`` is an unknown string.
    """
    # Per-feature mean, kept as a (1, d) row so it broadcasts over samples.
    mx = np.mean(X, axis=0).reshape((1, -1))
    # Center into a new array: the caller's data must not be altered, and an
    # in-place subtraction would also fail on integer input.
    X = X - mx

    # data split between classes
    d = X.shape[1]
    xc = split_classes(X, y)
    n_classes = len(xc)
    if n_classes < 2:
        raise ValueError("wda needs at least two classes in y")

    # compute uniform weights
    wc = [np.ones((x.shape[0]), dtype=np.float32) / x.shape[0] for x in xc]

    def cost(P):
        # Project every class once per evaluation of the loss.
        projected = [np.dot(x, P) for x in xc]

        loss_w = 0  # within-class transport cost
        loss_b = 0  # between-class transport cost
        for i in range(n_classes):
            for j in range(i, n_classes):
                M = dist(projected[i], projected[j])
                G = sinkhorn(wc[i], wc[j], M, reg, k)
                if i == j:
                    loss_w += np.sum(G * M)
                else:
                    loss_b += np.sum(G * M)

        # Ratio within/between: minimizing it compacts the classes while
        # pushing them apart.
        return loss_w / loss_b

    # declare manifold and problem
    manifold = Stiefel(d, p)
    problem = Problem(manifold=manifold, cost=cost)

    # declare solver and solve
    if solver is None:
        solver = SteepestDescent(maxiter=maxiter, logverbosity=verbose)
    elif solver in ['tr', 'TrustRegions']:
        solver = TrustRegions(maxiter=maxiter, logverbosity=verbose)
    elif isinstance(solver, str):
        raise ValueError("unknown solver name: %r" % (solver,))

    # NOTE(review): with logverbosity >= 1 some pymanopt versions return
    # (point, log) rather than the point alone -- confirm before relying on
    # verbose > 0.
    Popt = solver.solve(problem)

    def proj(X):
        return (X - mx).dot(Popt)

    return Popt, proj