summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoreloitanguy <69361683+eloitanguy@users.noreply.github.com>2022-05-11 08:57:54 +0200
committerGitHub <noreply@github.com>2022-05-11 08:57:54 +0200
commitd6bf10d8502b1c69f58f009b16634a110053eca1 (patch)
tree8d74efb46fa79063f7c2285f1d99c41b5b2b9ac3
parentc1ccfc45350f8db3fa78d91b84eb4286bcf36e69 (diff)
[WIP] Graphical tweaks for GWB + fixed seed method for the partial gromov test (#376)
* GWB first solver version * tests + example for gwb (untested) + free_bar doc fix * improved doc, fixed minor bugs, better example visu * minor doc + visu fixes * plot GWB pep8 fix * fixed partial gromov test reproductibility * added an animation for the GWB visu * added PR num * minor doc fixes + better gwb logo * GWB graphical tweaks + better seed method for partial gromov test * fixed PR number * refixed seed issue * seed fix fix fix Co-authored-by: RĂ©mi Flamary <remi.flamary@gmail.com>
-rw-r--r--RELEASES.md3
-rw-r--r--examples/barycenters/plot_generalized_free_support_barycenter.py11
-rwxr-xr-xtest/test_partial.py10
3 files changed, 13 insertions, 11 deletions
diff --git a/RELEASES.md b/RELEASES.md
index c06721f..76385d6 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -4,14 +4,13 @@
#### New features
-- Added Generalized Wasserstein Barycenter solver + example (PR #372)
+- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
#### Closed issues
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
(Issue #371, PR #373)
-
## 0.8.2
This releases introduces several new notable features. The less important
diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py
index 9af1953..e685ec7 100644
--- a/examples/barycenters/plot_generalized_free_support_barycenter.py
+++ b/examples/barycenters/plot_generalized_free_support_barycenter.py
@@ -33,8 +33,8 @@ import matplotlib.animation as animation
# Input measures
sub_sample_factor = 8
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
-I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
-I3 = pl.imread('../../data/heart.png').astype(np.float64)[::sub_sample_factor, ::sub_sample_factor, 2]
+I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
+I3 = pl.imread('../../data/heart.png').astype(np.float64)[::-sub_sample_factor, ::sub_sample_factor, 2]
sz = I1.shape[0]
UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))
@@ -145,8 +145,11 @@ def _init():
def _update_plot(i):
- ax.view_init(elev=i, azim=4 * i)
+ if i < 45:
+ ax.view_init(elev=0, azim=4 * i)
+ else:
+ ax.view_init(elev=i - 45, azim=4 * i)
return fig,
-ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=90, interval=50, blit=True, repeat_delay=2000)
+ani = animation.FuncAnimation(fig, _update_plot, init_func=_init, frames=136, interval=50, blit=True, repeat_delay=2000)
diff --git a/test/test_partial.py b/test/test_partial.py
index e07377b..33fc259 100755
--- a/test/test_partial.py
+++ b/test/test_partial.py
@@ -137,7 +137,7 @@ def test_partial_wasserstein():
def test_partial_gromov_wasserstein():
- np.random.seed(42)
+ rng = np.random.RandomState(seed=42)
n_samples = 20 # nb samples
n_noise = 10 # nb of samples (noise)
@@ -150,11 +150,11 @@ def test_partial_gromov_wasserstein():
mu_t = np.array([0, 0, 0])
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s)
- xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) + 1) * 4)), axis=0)
+ xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, rng)
+ xs = np.concatenate((xs, ((rng.rand(n_noise, 2) + 1) * 4)), axis=0)
P = sp.linalg.sqrtm(cov_t)
- xt = np.random.randn(n_samples, 3).dot(P) + mu_t
- xt = np.concatenate((xt, ((np.random.rand(n_noise, 3) + 1) * 10)), axis=0)
+ xt = rng.randn(n_samples, 3).dot(P) + mu_t
+ xt = np.concatenate((xt, ((rng.rand(n_noise, 3) + 1) * 10)), axis=0)
xt2 = xs[::-1].copy()
C1 = ot.dist(xs, xs)