From 0411ea22a96f9c22af30156b45c16ef39ffb520d Mon Sep 17 00:00:00 2001 From: RĂ©mi Flamary Date: Thu, 15 Dec 2022 09:28:01 +0100 Subject: [MRG] New API for OT solver (with pre-computed ground cost matrix) (#388) * new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort --- test/test_backend.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'test/test_backend.py') diff --git a/test/test_backend.py b/test/test_backend.py index 311c075..3628f61 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -274,6 +274,8 @@ def test_empty_backend(): nx.inv(M) with pytest.raises(NotImplementedError): nx.sqrtm(M) + with pytest.raises(NotImplementedError): + nx.kl_div(M, M) with pytest.raises(NotImplementedError): nx.isfinite(M) with pytest.raises(NotImplementedError): @@ -592,6 +594,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("matrix square root") + A = nx.kl_div(nx.abs(Mb), nx.abs(Mb) + 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append("Kullback-Leibler divergence") + A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0) A = nx.isfinite(A) lst_b.append(nx.to_numpy(A)) -- cgit v1.2.3