summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortgnassou <66993815+tgnassou@users.noreply.github.com>2023-02-28 08:35:12 +0100
committerGitHub <noreply@github.com>2023-02-28 08:35:12 +0100
commita6d5d75c6ca584ab9736b528810b3595f2571d82 (patch)
tree1fb295f86cc055222cc223ae00d05a674bb3a46c
parenta313e21f223af16cf21d3b7dd01bd0c6345d574c (diff)
[MRG] Add method argument to sinkhorn Transport (#440)
* add method argument to sinkhron transport' * update release
-rw-r--r--RELEASES.md1
-rw-r--r--ot/da.py5
2 files changed, 4 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md
index 292d1df..e251c30 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -14,6 +14,7 @@
- New API for OT solver using function `ot.solve` (PR #388)
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
+- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
#### Closed issues
diff --git a/ot/da.py b/ot/da.py
index 35e303b..5067a69 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -1417,12 +1417,13 @@ class SinkhornTransport(BaseTransport):
Sciences, 7(3), 1853-1882.
"""
- def __init__(self, reg_e=1., max_iter=1000,
+ def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
tol=10e-9, verbose=False, log=False,
metric="sqeuclidean", norm=None,
distribution_estimation=distribution_estimation_uniform,
out_of_sample_map='ferradans', limit_max=np.infty):
self.reg_e = reg_e
+ self.method = method
self.max_iter = max_iter
self.tol = tol
self.verbose = verbose
@@ -1463,7 +1464,7 @@ class SinkhornTransport(BaseTransport):
# coupling estimation
returned_ = sinkhorn(
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
- numItermax=self.max_iter, stopThr=self.tol,
+ method=self.method, numItermax=self.max_iter, stopThr=self.tol,
verbose=self.verbose, log=self.log)
# deal with the value of log