diff options
author | Rémi Flamary <remi.flamary@gmail.com> | 2020-05-05 07:52:16 +0200 |
---|---|---|
committer | Rémi Flamary <remi.flamary@gmail.com> | 2020-05-05 07:52:16 +0200 |
commit | 217ffbb73e15490935172f129e0ee51449f11bb6 (patch) | |
tree | 17035d9b6f3a19c4916b60b7d642a6cbd06ca26b /examples/others | |
parent | 0a1e8cd6d63b354fe7dfe0439c08417ecf988121 (diff) |
cleanup WDA example with proper seeds
Diffstat (limited to 'examples/others')
-rw-r--r-- | examples/others/plot_WDA.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/examples/others/plot_WDA.py b/examples/others/plot_WDA.py index 5e17433..009d902 100644 --- a/examples/others/plot_WDA.py +++ b/examples/others/plot_WDA.py @@ -33,6 +33,8 @@ from ot.dr import wda, fda n = 1000 # nb samples in source and target datasets nz = 0.2 +np.random.RandomState(1) + # generate circle dataset t = np.random.rand(n) * 2 * np.pi ys = np.floor((np.arange(n) * 1.0 / n * 3)) + 1 @@ -88,7 +90,11 @@ reg = 1e0 k = 10 maxiter = 100 -Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter) +P0 = np.random.randn(xs.shape[1], p) + +P0 /= np.sqrt(np.sum(P0**2, 0, keepdims=True)) + +Pwda, projwda = wda(xs, ys, p, reg, k, maxiter=maxiter, P0=P0) ############################################################################## |