summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>2022-09-27 15:17:53 +0200
committerGitHub <noreply@github.com>2022-09-27 15:17:53 +0200
commite433775c2015eb85c2683b6955618c2836f001bc (patch)
tree2f0f950efeeed87cad27301a3852a6272550ad7f
parentb295ffccc95149c1d63e805e1ca6f027a4071e2a (diff)
[MRG] Crash when computing weightless Hamming distance & Doc build (#402)
* Bug solve * Releases.md updated * pep8 * attempt to solve docs building bug * releases.md
-rw-r--r--RELEASES.md2
-rw-r--r--examples/barycenters/plot_barycenter_1D.py4
-rw-r--r--examples/unbalanced-partial/plot_UOT_barycenter_1D.py4
-rw-r--r--ot/utils.py4
-rw-r--r--test/test_utils.py1
5 files changed, 10 insertions, 5 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 8b4f0de..1c7b7da 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -23,6 +23,8 @@ incomplete transport plan above a certain size (slightly above 46k, its square b
roughly 2^31) (PR #381)
- Error raised when mass mismatch in emd2 (PR #386)
- Fixed an issue where a pytorch example would throw an error if executed on a GPU (Issue #389, PR #391)
+- Added a work-around for scipy's bug, where you cannot compute the Hamming distance with a "None" weight attribute. (Issue #400, PR #402)
+- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
## 0.8.2
diff --git a/examples/barycenters/plot_barycenter_1D.py b/examples/barycenters/plot_barycenter_1D.py
index 2373e99..8096245 100644
--- a/examples/barycenters/plot_barycenter_1D.py
+++ b/examples/barycenters/plot_barycenter_1D.py
@@ -106,7 +106,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
@@ -128,7 +128,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = plt.gcf().gca(projection='3d')
+ax = plt.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
diff --git a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
index 931798b..8d227c0 100644
--- a/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_barycenter_1D.py
@@ -127,7 +127,7 @@ for i, z in enumerate(zs):
ys = B_l2[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = pl.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
@@ -149,7 +149,7 @@ for i, z in enumerate(zs):
ys = B_wass[:, i]
verts.append(list(zip(x, ys)))
-ax = pl.gcf().gca(projection='3d')
+ax = pl.gcf().add_subplot(projection='3d')
poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
poly.set_alpha(0.7)
diff --git a/ot/utils.py b/ot/utils.py
index 57fb4a4..e3437da 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -234,7 +234,9 @@ def dist(x1, x2=None, metric='sqeuclidean', p=2, w=None):
else:
if isinstance(metric, str) and metric.endswith("minkowski"):
return cdist(x1, x2, metric=metric, p=p, w=w)
- return cdist(x1, x2, metric=metric, w=w)
+ if w is not None:
+ return cdist(x1, x2, metric=metric, w=w)
+ return cdist(x1, x2, metric=metric)
def dist0(n, method='lin_square'):
diff --git a/test/test_utils.py b/test/test_utils.py
index 3cfd295..19b6365 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -143,6 +143,7 @@ def test_dist():
for metric in metrics_w:
print(metric)
ot.dist(x, x, metric=metric, p=3, w=np.random.random((2, )))
+ ot.dist(x, x, metric=metric, p=3, w=None) # check that not having any weight does not cause issues
for metric in metrics:
print(metric)
ot.dist(x, x, metric=metric, p=3)