diff options
author | tgnassou <66993815+tgnassou@users.noreply.github.com> | 2023-02-28 08:35:12 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-28 08:35:12 +0100 |
commit | a6d5d75c6ca584ab9736b528810b3595f2571d82 (patch) | |
tree | 1fb295f86cc055222cc223ae00d05a674bb3a46c | |
parent | a313e21f223af16cf21d3b7dd01bd0c6345d574c (diff) |
[MRG] Add method argument to sinkhorn Transport (#440)
* add method argument to sinkhron transport'
* update release
-rw-r--r-- | RELEASES.md | 1 | ||||
-rw-r--r-- | ot/da.py | 5 |
2 files changed, 4 insertions, 2 deletions
diff --git a/RELEASES.md b/RELEASES.md index 292d1df..e251c30 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ - New API for OT solver using function `ot.solve` (PR #388) - Backend version of `ot.partial` and `ot.smooth` (PR #388) - Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437) +- Add parameters method in `ot.da.SinkhornTransport` (PR #440) #### Closed issues @@ -1417,12 +1417,13 @@ class SinkhornTransport(BaseTransport): Sciences, 7(3), 1853-1882. """ - def __init__(self, reg_e=1., max_iter=1000, + def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000, tol=10e-9, verbose=False, log=False, metric="sqeuclidean", norm=None, distribution_estimation=distribution_estimation_uniform, out_of_sample_map='ferradans', limit_max=np.infty): self.reg_e = reg_e + self.method = method self.max_iter = max_iter self.tol = tol self.verbose = verbose @@ -1463,7 +1464,7 @@ class SinkhornTransport(BaseTransport): # coupling estimation returned_ = sinkhorn( a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e, - numItermax=self.max_iter, stopThr=self.tol, + method=self.method, numItermax=self.max_iter, stopThr=self.tol, verbose=self.verbose, log=self.log) # deal with the value of log |