summaryrefslogtreecommitdiff
path: root/ot/plot.py
diff options
context:
space:
mode:
authorRĂ©mi Flamary <remi.flamary@gmail.com>2022-03-24 14:13:25 +0100
committerGitHub <noreply@github.com>2022-03-24 14:13:25 +0100
commit82452e0f5f6dae05c7a1cc384e7a1fb62ae7e0d5 (patch)
tree051871e3dc63e6bba1d0ecb1df6229796edd33bb /ot/plot.py
parent767171593f2a98a26b9a39bf110a45085e3b982e (diff)
[MRG] Add factored coupling (#358)
* add gfactored ot * pep8 and add doc * add exmaple for factotred OT * final number of PR * correct test on backends * remove useless loss * better tests
Diffstat (limited to 'ot/plot.py')
-rw-r--r--ot/plot.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/ot/plot.py b/ot/plot.py
index 2208c90..8ade2eb 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -85,8 +85,13 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
if ('color' not in kwargs) and ('c' not in kwargs):
kwargs['color'] = 'k'
mx = G.max()
+ if 'alpha' in kwargs:
+ scale = kwargs['alpha']
+ del kwargs['alpha']
+ else:
+ scale = 1
for i in range(xs.shape[0]):
for j in range(xt.shape[0]):
if G[i, j] / mx > thr:
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
- alpha=G[i, j] / mx, **kwargs)
+ alpha=G[i, j] / mx * scale, **kwargs)