author     Nathan Cassereau <84033440+ncassereau-idris@users.noreply.github.com>    2022-03-24 10:53:47 +0100
committer  GitHub <noreply@github.com>                                              2022-03-24 10:53:47 +0100
commit     767171593f2a98a26b9a39bf110a45085e3b982e (patch)
tree       4eb4bcc657efc53a65c3fb4439bd0e0e106b6745
parent     9b9d2221d257f40ea3eb58b279b30d69162d62bb (diff)
[MRG] Domain adaptation and unbalanced solvers with backend support (#343)
* First draft
* Add matrix inverse and square root to backend
* Eigen decomposition for older versions of pytorch (1.8.1 and older)
* Corrected eigen decomposition for pytorch 1.8.1 and older
* Spectral theorem is a thing
* Optimization
* small optimization
* More functions converted
* pep8
* remove a warning and prepare torch meshgrid for future torch release (which will change default indexing)
* dots and pep8
* Meshgrid corrected for older version and prepared for future versions changes
* New backend functions
* Base transport
* LinearTransport
* All transport classes + pep8
* PR added to release file
* Jcpot barycenter test
* unbalanced with backend
* pep8
* bug solve
* test of domain adaptation with backends
* solve bug for tic toc & macos
* solving scipy deprecation warning
* solving scipy deprecation warning attempt2
* solving scipy deprecation warning attempt3
* A warning is triggered when a float->int conversion is detected
* bug solve
* docs
* release file updated
* Better handling of float->int conversion in EMD
* Corrected test for is_floating_point
* docs
* release file updated
* cupy does not allow implicit cast
* fromnumpy
* added test
* test da tf jax
* test unbalanced with no provided histogram
* using type_as argument in unif function correctly
* pep8
* transport plan cast in emd changed behaviour, now trying to cast as histogram's dtype, defaulting to cost matrix

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
-rw-r--r--  RELEASES.md               2
-rw-r--r--  ot/backend.py           304
-rw-r--r--  ot/bregman.py            17
-rw-r--r--  ot/da.py                382
-rw-r--r--  ot/lp/__init__.py        83
-rw-r--r--  ot/optim.py              11
-rw-r--r--  ot/unbalanced.py        302
-rw-r--r--  ot/utils.py              26
-rw-r--r--  test/test_1d_solver.py   28
-rw-r--r--  test/test_backend.py     66
-rw-r--r--  test/test_bregman.py     81
-rw-r--r--  test/test_da.py         307
-rw-r--r--  test/test_gromov.py     147
-rw-r--r--  test/test_optim.py       17
-rw-r--r--  test/test_ot.py          19
-rw-r--r--  test/test_sliced.py      32
-rw-r--r--  test/test_unbalanced.py 157
-rw-r--r--  test/test_weak.py         4
18 files changed, 1160 insertions, 825 deletions
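
A minimal usage sketch of the headline feature (not part of this commit; it assumes a POT build containing this PR plus PyTorch installed, and uses the existing ot.da.SinkhornTransport estimator API): the domain-adaptation solvers now accept tensors from any supported backend and keep the coupling in that backend.

    import numpy as np
    import torch
    import ot

    rng = np.random.RandomState(0)
    Xs = torch.tensor(rng.randn(50, 2))            # source samples (torch backend)
    Xt = torch.tensor(rng.randn(60, 2) + 1.0)      # shifted target samples

    mapping = ot.da.SinkhornTransport(reg_e=1e-1)
    mapping.fit(Xs=Xs, Xt=Xt)                      # coupling_ stays a torch tensor
    Xs_mapped = mapping.transform(Xs=Xs)           # barycentric mapping, still torch
    print(type(mapping.coupling_), Xs_mapped.shape)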
diff --git a/RELEASES.md b/RELEASES.md
index 0f1f231..86b401a 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -11,6 +11,7 @@
of the regularization parameter (PR #336).
- Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
- Add weak OT solver + example (PR #341).
+- Add backend support for Domain Adaptation and Unbalanced solvers (PR #343).
- Add (F)GW linear dictionary learning solvers + example (PR #319)
- Add links to related PR and Issues in the doc release page (PR #350)
@@ -19,6 +20,7 @@
- Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
PR #338)
- Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
+- Warning when feeding integer cost matrix to EMD solver resulting in an integer transport plan (Issue #345, PR #343)
- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA
tensors (Issue #351, PR #352)
diff --git a/ot/backend.py b/ot/backend.py
index 6e0bc3d..361ffba 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -87,7 +87,9 @@ Performance
# License: MIT License
import numpy as np
-import scipy.special as scipy
+import scipy
+import scipy.linalg
+import scipy.special as special
from scipy.sparse import issparse, coo_matrix, csr_matrix
import warnings
import time
@@ -102,7 +104,7 @@ except ImportError:
try:
import jax
import jax.numpy as jnp
- import jax.scipy.special as jscipy
+ import jax.scipy.special as jspecial
from jax.lib import xla_bridge
jax_type = jax.numpy.ndarray
except ImportError:
@@ -202,13 +204,29 @@ class Backend():
def __str__(self):
return self.__name__
- # convert to numpy
- def to_numpy(self, a):
+ # convert batch of tensors to numpy
+ def to_numpy(self, *arrays):
+ """Returns the numpy version of tensors"""
+ if len(arrays) == 1:
+ return self._to_numpy(arrays[0])
+ else:
+ return [self._to_numpy(array) for array in arrays]
+
+ # convert a tensor to numpy
+ def _to_numpy(self, a):
"""Returns the numpy version of a tensor"""
raise NotImplementedError()
- # convert from numpy
- def from_numpy(self, a, type_as=None):
+ # convert batch of arrays from numpy
+ def from_numpy(self, *arrays, type_as=None):
+ """Creates tensors cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
+ if len(arrays) == 1:
+ return self._from_numpy(arrays[0], type_as=type_as)
+ else:
+ return [self._from_numpy(array, type_as=type_as) for array in arrays]
+
+ # convert an array from numpy
+ def _from_numpy(self, a, type_as=None):
"""Creates a tensor cloning a numpy array, with the given precision (defaulting to input's precision) and the given device (in case of GPUs)"""
raise NotImplementedError()
@@ -536,6 +554,16 @@ class Backend():
"""
raise NotImplementedError()
+ def argmin(self, a, axis=None):
+ r"""
+ Returns the indices of the minimum values of a tensor along given dimensions.
+
+ This function follows the api from :any:`numpy.argmin`
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.argmin.html
+ """
+ raise NotImplementedError()
+
def mean(self, a, axis=None):
r"""
Computes the arithmetic mean of a tensor along given dimensions.
@@ -786,6 +814,72 @@ class Backend():
"""
raise NotImplementedError()
+ def solve(self, a, b):
+ r"""
+ Solves a linear matrix equation, or system of linear scalar equations.
+
+ This function follows the api from :any:`numpy.linalg.solve`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.solve.html
+ """
+ raise NotImplementedError()
+
+ def trace(self, a):
+ r"""
+ Returns the sum along diagonals of the array.
+
+ This function follows the api from :any:`numpy.trace`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.trace.html
+ """
+ raise NotImplementedError()
+
+ def inv(self, a):
+ r"""
+ Computes the inverse of a matrix.
+
+ This function follows the api from :any:`scipy.linalg.inv`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.inv.html
+ """
+ raise NotImplementedError()
+
+ def sqrtm(self, a):
+ r"""
+ Computes the matrix square root. Requires the input to be positive definite.
+
+ This function follows the api from :any:`scipy.linalg.sqrtm`.
+
+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.sqrtm.html
+ """
+ raise NotImplementedError()
+
+ def isfinite(self, a):
+ r"""
+ Tests element-wise for finiteness (not infinity and not Not a Number).
+
+ This function follows the api from :any:`numpy.isfinite`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.isfinite.html
+ """
+ raise NotImplementedError()
+
+ def array_equal(self, a, b):
+ r"""
+ True if two arrays have the same shape and elements, False otherwise.
+
+ This function follows the api from :any:`numpy.array_equal`.
+
+ See: https://numpy.org/doc/stable/reference/generated/numpy.array_equal.html
+ """
+ raise NotImplementedError()
+
+ def is_floating_point(self, a):
+ r"""
+ Returns whether or not the input consists of floats
+ """
+ raise NotImplementedError()
+
class NumpyBackend(Backend):
"""
@@ -802,10 +896,10 @@ class NumpyBackend(Backend):
rng_ = np.random.RandomState()
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if type_as is None:
return a
elif isinstance(a, float):
@@ -936,6 +1030,9 @@ class NumpyBackend(Backend):
def argmax(self, a, axis=None):
return np.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return np.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return np.mean(a, axis=axis)
@@ -955,7 +1052,7 @@ class NumpyBackend(Backend):
return np.unique(a)
def logsumexp(self, a, axis=None):
- return scipy.logsumexp(a, axis=axis)
+ return special.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return np.stack(arrays, axis)
@@ -1004,8 +1101,11 @@ class NumpyBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return np.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return np.where(condition)
+ else:
+ return np.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -1046,6 +1146,27 @@ class NumpyBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return np.linalg.solve(a, b)
+
+ def trace(self, a):
+ return np.trace(a)
+
+ def inv(self, a):
+ return scipy.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return scipy.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return np.isfinite(a)
+
+ def array_equal(self, a, b):
+ return np.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class JaxBackend(Backend):
"""
@@ -1075,13 +1196,15 @@ class JaxBackend(Backend):
jax.device_put(jnp.array(1, dtype=jnp.float64), d)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return np.array(a)
def _change_device(self, a, type_as):
return jax.device_put(a, type_as.device_buffer.device())
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return jnp.array(a)
else:
@@ -1216,6 +1339,9 @@ class JaxBackend(Backend):
def argmax(self, a, axis=None):
return jnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return jnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return jnp.mean(a, axis=axis)
@@ -1235,7 +1361,7 @@ class JaxBackend(Backend):
return jnp.unique(a)
def logsumexp(self, a, axis=None):
- return jscipy.logsumexp(a, axis=axis)
+ return jspecial.logsumexp(a, axis=axis)
def stack(self, arrays, axis=0):
return jnp.stack(arrays, axis)
@@ -1293,8 +1419,11 @@ class JaxBackend(Backend):
# Currently, JAX does not support sparse matrices
return a
- def where(self, condition, x, y):
- return jnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return jnp.where(condition)
+ else:
+ return jnp.where(condition, x, y)
def copy(self, a):
# No need to copy, JAX arrays are immutable
@@ -1339,6 +1468,28 @@ class JaxBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+ def solve(self, a, b):
+ return jnp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return jnp.trace(a)
+
+ def inv(self, a):
+ return jnp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = jnp.linalg.eigh(a)
+ return (V * jnp.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return jnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return jnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TorchBackend(Backend):
"""
@@ -1384,10 +1535,10 @@ class TorchBackend(Backend):
self.ValFunction = ValFunction
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.cpu().detach().numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
if isinstance(a, float):
a = np.array(a)
if type_as is None:
@@ -1564,6 +1715,9 @@ class TorchBackend(Backend):
def argmax(self, a, axis=None):
return torch.argmax(a, dim=axis)
+ def argmin(self, a, axis=None):
+ return torch.argmin(a, dim=axis)
+
def mean(self, a, axis=None):
if axis is not None:
return torch.mean(a, dim=axis)
@@ -1580,8 +1734,11 @@ class TorchBackend(Backend):
return torch.linspace(start, stop, num, dtype=torch.float64)
def meshgrid(self, a, b):
- X, Y = torch.meshgrid(a, b)
- return X.T, Y.T
+ try:
+ return torch.meshgrid(a, b, indexing="xy")
+ except TypeError:
+ X, Y = torch.meshgrid(a, b)
+ return X.T, Y.T
def diag(self, a, k=0):
return torch.diag(a, diagonal=k)
@@ -1659,8 +1816,11 @@ class TorchBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return torch.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return torch.where(condition)
+ else:
+ return torch.where(condition, x, y)
def copy(self, a):
return torch.clone(a)
@@ -1718,6 +1878,28 @@ class TorchBackend(Backend):
torch.cuda.empty_cache()
return results
+ def solve(self, a, b):
+ return torch.linalg.solve(a, b)
+
+ def trace(self, a):
+ return torch.trace(a)
+
+ def inv(self, a):
+ return torch.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = torch.linalg.eigh(a)
+ return (V * torch.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return torch.isfinite(a)
+
+ def array_equal(self, a, b):
+ return torch.equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating_point
+
class CupyBackend(Backend): # pragma: no cover
"""
@@ -1741,10 +1923,12 @@ class CupyBackend(Backend): # pragma: no cover
cp.array(1, dtype=cp.float64)
]
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return cp.asnumpy(a)
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if type_as is None:
return cp.asarray(a)
else:
@@ -1884,6 +2068,9 @@ class CupyBackend(Backend): # pragma: no cover
def argmax(self, a, axis=None):
return cp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return cp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return cp.mean(a, axis=axis)
@@ -1982,8 +2169,11 @@ class CupyBackend(Backend): # pragma: no cover
else:
return a
- def where(self, condition, x, y):
- return cp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return cp.where(condition)
+ else:
+ return cp.where(condition, x, y)
def copy(self, a):
return a.copy()
@@ -2035,6 +2225,28 @@ class CupyBackend(Backend): # pragma: no cover
pinned_mempool.free_all_blocks()
return results
+ def solve(self, a, b):
+ return cp.linalg.solve(a, b)
+
+ def trace(self, a):
+ return cp.trace(a)
+
+ def inv(self, a):
+ return cp.linalg.inv(a)
+
+ def sqrtm(self, a):
+ L, V = cp.linalg.eigh(a)
+ return (V * self.sqrt(L)[None, :]) @ V.T
+
+ def isfinite(self, a):
+ return cp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return cp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.kind == "f"
+
class TensorflowBackend(Backend):
@@ -2060,13 +2272,16 @@ class TensorflowBackend(Backend):
"To use TensorflowBackend, you need to activate the tensorflow "
"numpy API. You can activate it by running: \n"
"from tensorflow.python.ops.numpy_ops import np_config\n"
- "np_config.enable_numpy_behavior()"
+ "np_config.enable_numpy_behavior()",
+ stacklevel=2
)
- def to_numpy(self, a):
+ def _to_numpy(self, a):
return a.numpy()
- def from_numpy(self, a, type_as=None):
+ def _from_numpy(self, a, type_as=None):
+ if isinstance(a, float):
+ a = np.array(a)
if not isinstance(a, self.__type__):
if type_as is None:
return tf.convert_to_tensor(a)
@@ -2208,6 +2423,9 @@ class TensorflowBackend(Backend):
def argmax(self, a, axis=None):
return tnp.argmax(a, axis=axis)
+ def argmin(self, a, axis=None):
+ return tnp.argmin(a, axis=axis)
+
def mean(self, a, axis=None):
return tnp.mean(a, axis=axis)
@@ -2309,8 +2527,11 @@ class TensorflowBackend(Backend):
else:
return a
- def where(self, condition, x, y):
- return tnp.where(condition, x, y)
+ def where(self, condition, x=None, y=None):
+ if x is None and y is None:
+ return tnp.where(condition)
+ else:
+ return tnp.where(condition, x, y)
def copy(self, a):
return tf.identity(a)
@@ -2364,3 +2585,24 @@ class TensorflowBackend(Backend):
results[key] = (t1 - t0) / n_runs
return results
+
+ def solve(self, a, b):
+ return tf.linalg.solve(a, b)
+
+ def trace(self, a):
+ return tf.linalg.trace(a)
+
+ def inv(self, a):
+ return tf.linalg.inv(a)
+
+ def sqrtm(self, a):
+ return tf.linalg.sqrtm(a)
+
+ def isfinite(self, a):
+ return tnp.isfinite(a)
+
+ def array_equal(self, a, b):
+ return tnp.array_equal(a, b)
+
+ def is_floating_point(self, a):
+ return a.dtype.is_floating
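
A small sketch of the backend additions above, exercised through the NumPy backend (illustrative only; get_backend and the method names solve, sqrtm, argmin, is_floating_point and the variadic to_numpy/from_numpy are the ones introduced or extended in this file):

    import numpy as np
    from ot.backend import get_backend

    A = np.array([[2.0, 0.5], [0.5, 1.0]])
    b = np.array([1.0, 2.0])
    nx = get_backend(A, b)

    x = nx.solve(A, b)                             # numpy.linalg.solve under the hood
    S = nx.sqrtm(A)                                # matrix square root of an SPD matrix
    print(nx.argmin(b), nx.is_floating_point(A))   # 0 True
    An, bn = nx.to_numpy(A, b)                     # batch conversion, new in this PR
    A2, b2 = nx.from_numpy(An, bn, type_as=A)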
diff --git a/ot/bregman.py b/ot/bregman.py
index fc20175..c06af2f 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -2525,8 +2525,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
# geometric interpolation
delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
- K0 = nx.dot(nx.diag(nx.dot(D.T, delta / inv_new)), K0)
-
+ K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
@@ -2656,16 +2655,16 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
classes = nx.unique(Ys[d])
# build the corresponding D_1 and D_2 matrices
- Dtmp1 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
- Dtmp2 = nx.zeros((nbclasses, nsk), type_as=Xs[0])
+ Dtmp1 = np.zeros((nbclasses, nsk))
+ Dtmp2 = np.zeros((nbclasses, nsk))
for c in classes:
- nbelemperclass = nx.sum(Ys[d] == c)
+ nbelemperclass = float(nx.sum(Ys[d] == c))
if nbelemperclass != 0:
- Dtmp1[int(c), Ys[d] == c] = 1.
- Dtmp2[int(c), Ys[d] == c] = 1. / (nbelemperclass)
- D1.append(Dtmp1)
- D2.append(Dtmp2)
+ Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1.
+ Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass)
+ D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0]))
+ D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0]))
# build the cost matrix and the Gibbs kernel
Mtmp = dist(Xs[d], Xt, metric=metric)
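
The hunk above only changes how jcpot_barycenter builds its class-indicator matrices D1/D2 (in plain numpy, then converted back with nx.from_numpy so cupy's stricter casting is respected). A usage sketch, assuming the usual jcpot_barycenter signature and two source domains with integer labels starting at 0:

    import numpy as np
    import ot

    rng = np.random.RandomState(42)
    Xs = [rng.randn(20, 2), rng.randn(30, 2)]             # two source domains
    Ys = [rng.randint(0, 2, 20), rng.randint(0, 2, 30)]   # class labels in {0, 1}
    Xt = rng.randn(25, 2)                                 # unlabeled target samples

    prop = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.0)
    print(prop)   # estimated target class proportions (sum to 1)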
diff --git a/ot/da.py b/ot/da.py
index 841f31a..0b9737e 100644
--- a/ot/da.py
+++ b/ot/da.py
@@ -12,12 +12,12 @@ Domain adaptation with optimal transport
# License: MIT License
import numpy as np
-import scipy.linalg as linalg
+from .backend import get_backend
from .bregman import sinkhorn, jcpot_barycenter
from .lp import emd
from .utils import unif, dist, kernel, cost_normalization, label_normalization, laplacian, dots
-from .utils import check_params, BaseEstimator
+from .utils import list_to_array, check_params, BaseEstimator
from .unbalanced import sinkhorn_unbalanced
from .optim import cg
from .optim import gcg
@@ -60,13 +60,13 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- labels_a : np.ndarray (ns,)
+ labels_a : array-like (ns,)
labels of samples in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples weights in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
reg : float
Regularization term for entropic regularization >0
@@ -86,7 +86,7 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -111,26 +111,28 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.cg : General regularized OT
"""
+ a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+ nx = get_backend(a, labels_a, b, M)
+
p = 0.5
epsilon = 1e-3
indices_labels = []
- classes = np.unique(labels_a)
+ classes = nx.unique(labels_a)
for c in classes:
- idxc, = np.where(labels_a == c)
+ idxc, = nx.where(labels_a == c)
indices_labels.append(idxc)
- W = np.zeros(M.shape)
-
+ W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
- W = np.ones(M.shape)
+ W = nx.ones(M.shape, type_as=M)
for (i, c) in enumerate(classes):
- majs = np.sum(transp[indices_labels[i]], axis=0)
+ majs = nx.sum(transp[indices_labels[i]], axis=0)
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs
@@ -174,13 +176,13 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- labels_a : np.ndarray (ns,)
+ labels_a : array-like (ns,)
labels of samples in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
reg : float
Regularization term for entropic regularization >0
@@ -200,7 +202,7 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -222,22 +224,25 @@ def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
ot.optim.gcg : Generalized conditional gradient for OT problems
"""
- lstlab = np.unique(labels_a)
+ a, labels_a, b, M = list_to_array(a, labels_a, b, M)
+ nx = get_backend(a, labels_a, b, M)
+
+ lstlab = nx.unique(labels_a)
def f(G):
res = 0
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
- res += np.linalg.norm(temp)
+ res += nx.norm(temp)
return res
def df(G):
- W = np.zeros(G.shape)
+ W = nx.zeros(G.shape, type_as=G)
for i in range(G.shape[1]):
for lab in lstlab:
temp = G[labels_a == lab, i]
- n = np.linalg.norm(temp)
+ n = nx.norm(temp)
if n:
W[labels_a == lab, i] = temp / n
return W
@@ -289,9 +294,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
mu : float,optional
Weight for the linear OT loss (>0)
@@ -315,9 +320,9 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
- L : (d, d) ndarray
+ L : (d, d) array-like
Linear mapping matrix ((:math:`d+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
@@ -336,13 +341,15 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
ot.optim.cg : General regularized OT
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
ns, nt, d = xs.shape[0], xt.shape[0], xt.shape[1]
if bias:
- xs1 = np.hstack((xs, np.ones((ns, 1))))
- xstxs = xs1.T.dot(xs1)
- Id = np.eye(d + 1)
+ xs1 = nx.concatenate((xs, nx.ones((ns, 1), type_as=xs)), axis=1)
+ xstxs = nx.dot(xs1.T, xs1)
+ Id = nx.eye(d + 1, type_as=xs)
Id[-1] = 0
I0 = Id[:, :-1]
@@ -350,8 +357,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
return x[:-1, :]
else:
xs1 = xs
- xstxs = xs1.T.dot(xs1)
- Id = np.eye(d)
+ xstxs = nx.dot(xs1.T, xs1)
+ Id = nx.eye(d, type_as=xs)
I0 = Id
def sel(x):
@@ -360,7 +367,8 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
if log:
log = {'err': []}
- a, b = unif(ns), unif(nt)
+ a = unif(ns, type_as=xs)
+ b = unif(nt, type_as=xt)
M = dist(xs, xt) * ns
G = emd(a, b, M)
@@ -368,23 +376,26 @@ def joint_OT_mapping_linear(xs, xt, mu=1, eta=0.001, bias=False, verbose=False,
def loss(L, G):
"""Compute full loss"""
- return np.sum((xs1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
- np.sum(G * M) + eta * np.sum(sel(L - I0) ** 2)
+ return (
+ nx.sum((nx.dot(xs1, L) - ns * nx.dot(G, xt)) ** 2)
+ + mu * nx.sum(G * M)
+ + eta * nx.sum(sel(L - I0) ** 2)
+ )
def solve_L(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(xstxs + eta * Id, xs1.T.dot(xst) + eta * I0)
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(xstxs + eta * Id, nx.dot(xs1.T, xst) + eta * I0)
def solve_G(L, G0):
"""Update G with CG algorithm"""
- xsi = xs1.dot(L)
+ xsi = nx.dot(xs1, L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt)) ** 2)
+ return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
def df(G):
- return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
@@ -481,9 +492,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
mu : float,optional
Weight for the linear OT loss (>0)
@@ -513,9 +524,9 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
- L : (ns, d) ndarray
+ L : (ns, d) array-like
Nonlinear mapping matrix ((:math:`n_s+1`, `d`) if bias)
log : dict
log dictionary return only if log==True in parameters
@@ -534,15 +545,17 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
ot.optim.cg : General regularized OT
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
ns, nt = xs.shape[0], xt.shape[0]
K = kernel(xs, xs, method=kerneltype, sigma=sigma)
if bias:
- K1 = np.hstack((K, np.ones((ns, 1))))
- Id = np.eye(ns + 1)
+ K1 = nx.concatenate((K, nx.ones((ns, 1), type_as=xs)), axis=1)
+ Id = nx.eye(ns + 1, type_as=xs)
Id[-1] = 0
- Kp = np.eye(ns + 1)
+ Kp = nx.eye(ns + 1, type_as=xs)
Kp[:ns, :ns] = K
# ls regu
@@ -550,12 +563,12 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
# Kreg=I
# RKHS regul
- K0 = K1.T.dot(K1) + eta * Kp
+ K0 = nx.dot(K1.T, K1) + eta * Kp
Kreg = Kp
else:
K1 = K
- Id = np.eye(ns)
+ Id = nx.eye(ns, type_as=xs)
# ls regul
# K0 = K1.T.dot(K1)+eta*I
@@ -568,7 +581,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
if log:
log = {'err': []}
- a, b = unif(ns), unif(nt)
+ a = unif(ns, type_as=xs)
+ b = unif(nt, type_as=xt)
M = dist(xs, xt) * ns
G = emd(a, b, M)
@@ -576,28 +590,31 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian',
def loss(L, G):
"""Compute full loss"""
- return np.sum((K1.dot(L) - ns * G.dot(xt)) ** 2) + mu * \
- np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L))
+ return (
+ nx.sum((nx.dot(K1, L) - ns * nx.dot(G, xt)) ** 2)
+ + mu * nx.sum(G * M)
+ + eta * nx.trace(dots(L.T, Kreg, L))
+ )
def solve_L_nobias(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(K0, xst)
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(K0, xst)
def solve_L_bias(G):
""" solve L problem with fixed G (least square)"""
- xst = ns * G.dot(xt)
- return np.linalg.solve(K0, K1.T.dot(xst))
+ xst = ns * nx.dot(G, xt)
+ return nx.solve(K0, nx.dot(K1.T, xst))
def solve_G(L, G0):
"""Update G with CG algorithm"""
- xsi = K1.dot(L)
+ xsi = nx.dot(K1, L)
def f(G):
- return np.sum((xsi - ns * G.dot(xt)) ** 2)
+ return nx.sum((xsi - ns * nx.dot(G, xt)) ** 2)
def df(G):
- return -2 * ns * (xsi - ns * G.dot(xt)).dot(xt.T)
+ return -2 * ns * nx.dot(xsi - ns * nx.dot(G, xt), xt.T)
G = cg(a, b, M, 1.0 / mu, f, df, G0=G0,
numItermax=numInnerItermax, stopThr=stopInnerThr)
@@ -681,15 +698,15 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
Parameters
----------
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
reg : float,optional
regularization added to the diagonals of covariances (>0)
- ws : np.ndarray (ns,1), optional
+ ws : array-like (ns,1), optional
weights for the source samples
- wt : np.ndarray (ns,1), optional
+ wt : array-like (ns,1), optional
weights for the target samples
bias: boolean, optional
estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
@@ -699,9 +716,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
Returns
-------
- A : (d, d) ndarray
+ A : (d, d) array-like
Linear operator
- b : (1, d) ndarray
+ b : (1, d) array-like
bias
log : dict
log dictionary return only if log==True in parameters
@@ -719,36 +736,38 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
"""
+ xs, xt = list_to_array(xs, xt)
+ nx = get_backend(xs, xt)
d = xs.shape[1]
if bias:
- mxs = xs.mean(0, keepdims=True)
- mxt = xt.mean(0, keepdims=True)
+ mxs = nx.mean(xs, axis=0)[None, :]
+ mxt = nx.mean(xt, axis=0)[None, :]
xs = xs - mxs
xt = xt - mxt
else:
- mxs = np.zeros((1, d))
- mxt = np.zeros((1, d))
+ mxs = nx.zeros((1, d), type_as=xs)
+ mxt = nx.zeros((1, d), type_as=xs)
if ws is None:
- ws = np.ones((xs.shape[0], 1)) / xs.shape[0]
+ ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
if wt is None:
- wt = np.ones((xt.shape[0], 1)) / xt.shape[0]
+ wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
- Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d)
- Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d)
+ Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(d, type_as=xs)
+ Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(d, type_as=xt)
- Cs12 = linalg.sqrtm(Cs)
- Cs_12 = linalg.inv(Cs12)
+ Cs12 = nx.sqrtm(Cs)
+ Cs_12 = nx.inv(Cs12)
- M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12)))
+ M0 = nx.sqrtm(dots(Cs12, Ct, Cs12))
- A = Cs_12.dot(M0.dot(Cs_12))
+ A = dots(Cs_12, M0, Cs_12)
- b = mxt - mxs.dot(A)
+ b = mxt - nx.dot(mxs, A)
if log:
log = {}
@@ -798,15 +817,15 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Parameters
----------
- a : np.ndarray (ns,)
+ a : array-like (ns,)
samples weights in the source domain
- b : np.ndarray (nt,)
+ b : array-like (nt,)
samples weights in the target domain
- xs : np.ndarray (ns,d)
+ xs : array-like (ns,d)
samples in the source domain
- xt : np.ndarray (nt,d)
+ xt : array-like (nt,d)
samples in the target domain
- M : np.ndarray (ns,nt)
+ M : array-like (ns,nt)
loss matrix
sim : string, optional
Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
@@ -834,7 +853,7 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
Returns
-------
- gamma : (ns, nt) ndarray
+ gamma : (ns, nt) array-like
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters
@@ -862,9 +881,12 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
raise ValueError(
'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__))
+ a, b, xs, xt, M = list_to_array(a, b, xs, xt, M)
+ nx = get_backend(a, b, xs, xt, M)
+
if sim == 'gauss':
if sim_param is None:
- sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
+ sim_param = 1 / (2 * (nx.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
sS = kernel(xs, xs, method=sim, sigma=sim_param)
sT = kernel(xt, xt, method=sim, sigma=sim_param)
@@ -874,9 +896,13 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
from sklearn.neighbors import kneighbors_graph
- sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray()
+ sS = nx.from_numpy(kneighbors_graph(
+ X=nx.to_numpy(xs), n_neighbors=int(sim_param)
+ ).toarray(), type_as=xs)
sS = (sS + sS.T) / 2
- sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray()
+ sT = nx.from_numpy(kneighbors_graph(
+ X=nx.to_numpy(xt), n_neighbors=int(sim_param)
+ ).toarray(), type_as=xt)
sT = (sT + sT.T) / 2
else:
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
@@ -885,12 +911,14 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
lT = laplacian(sT)
def f(G):
- return alpha * np.trace(np.dot(xt.T, np.dot(G.T, np.dot(lS, np.dot(G, xt))))) \
- + (1 - alpha) * np.trace(np.dot(xs.T, np.dot(G, np.dot(lT, np.dot(G.T, xs)))))
+ return (
+ alpha * nx.trace(dots(xt.T, G.T, lS, G, xt))
+ + (1 - alpha) * nx.trace(dots(xs.T, G, lT, G.T, xs))
+ )
ls2 = lS + lS.T
lt2 = lT + lT.T
- xt2 = np.dot(xt, xt.T)
+ xt2 = nx.dot(xt, xt.T)
if reg == 'disp':
Cs = -eta * alpha / xs.shape[0] * dots(ls2, xs, xt.T)
@@ -898,8 +926,10 @@ def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, al
M = M + Cs + Ct
def df(G):
- return alpha * np.dot(ls2, np.dot(G, xt2))\
- + (1 - alpha) * np.dot(xs, np.dot(xs.T, np.dot(G, lt2)))
+ return (
+ alpha * dots(ls2, G, xt2)
+ + (1 - alpha) * dots(xs, xs.T, G, lt2)
+ )
return cg(a, b, M, reg=eta, f=f, df=df, G0=None, numItermax=numItermax, numItermaxEmd=numInnerItermax,
stopThr=stopThr, stopThr2=stopInnerThr, verbose=verbose, log=log)
@@ -919,7 +949,7 @@ def distribution_estimation_uniform(X):
The uniform distribution estimated from :math:`\mathbf{X}`
"""
- return unif(X.shape[0])
+ return unif(X.shape[0], type_as=X)
class BaseTransport(BaseEstimator):
@@ -973,6 +1003,7 @@ class BaseTransport(BaseEstimator):
self : object
Returns self.
"""
+ nx = self._get_backend(Xs, ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt):
@@ -984,14 +1015,14 @@ class BaseTransport(BaseEstimator):
if (ys is not None) and (yt is not None):
if self.limit_max != np.infty:
- self.limit_max = self.limit_max * np.max(self.cost_)
+ self.limit_max = self.limit_max * nx.max(self.cost_)
# assumes labeled source samples occupy the first rows
# and labeled target samples occupy the first columns
- classes = [c for c in np.unique(ys) if c != -1]
+ classes = [c for c in nx.unique(ys) if c != -1]
for c in classes:
- idx_s = np.where((ys != c) & (ys != -1))
- idx_t = np.where(yt == c)
+ idx_s = nx.where((ys != c) & (ys != -1))
+ idx_t = nx.where(yt == c)
# all the coefficients corresponding to a source sample
# and a target sample :
@@ -1062,23 +1093,24 @@ class BaseTransport(BaseEstimator):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if np.array_equal(self.xs_, Xs):
+ if nx.array_equal(self.xs_, Xs):
# perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, axis=1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.xt_)
+ transp_Xs = nx.dot(transp, self.xt_)
else:
# perform out of sample mapping
- indices = np.arange(Xs.shape[0])
+ indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -1087,20 +1119,20 @@ class BaseTransport(BaseEstimator):
for bi in batch_ind:
# get the nearest neighbor in the source domain
D0 = dist(Xs[bi], self.xs_)
- idx = np.argmin(D0, axis=1)
+ idx = nx.argmin(D0, axis=1)
# transport the source samples
- transp = self.coupling_ / np.sum(
- self.coupling_, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_ = np.dot(transp, self.xt_)
+ transp = self.coupling_ / nx.sum(
+ self.coupling_, axis=1)[:, None]
+ transp[~ nx.isfinite(transp)] = 0
+ transp_Xs_ = nx.dot(transp, self.xt_)
# define the transported points
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.xs_[idx, :]
transp_Xs.append(transp_Xs_)
- transp_Xs = np.concatenate(transp_Xs, axis=0)
+ transp_Xs = nx.concatenate(transp_Xs, axis=0)
return transp_Xs
@@ -1127,26 +1159,27 @@ class BaseTransport(BaseEstimator):
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(ys=ys):
- ysTemp = label_normalization(np.copy(ys))
- classes = np.unique(ysTemp)
+ ysTemp = label_normalization(nx.copy(ys))
+ classes = nx.unique(ysTemp)
n = len(classes)
- D1 = np.zeros((n, len(ysTemp)))
+ D1 = nx.zeros((n, len(ysTemp)), type_as=self.coupling_)
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
+ transp = self.coupling_ / nx.sum(self.coupling_, axis=0)[None, :]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
for c in classes:
D1[int(c), ysTemp == c] = 1
# compute propagated labels
- transp_ys = np.dot(D1, transp)
+ transp_ys = nx.dot(D1, transp)
return transp_ys.T
@@ -1176,23 +1209,24 @@ class BaseTransport(BaseEstimator):
transp_Xt : array-like, shape (n_source_samples, n_features)
The transported target samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- if np.array_equal(self.xt_, Xt):
+ if nx.array_equal(self.xt_, Xt):
# perform standard barycentric mapping
- transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
+ transp_ = self.coupling_.T / nx.sum(self.coupling_, 0)[:, None]
# set nans to 0
- transp_[~ np.isfinite(transp_)] = 0
+ transp_[~ nx.isfinite(transp_)] = 0
# compute transported samples
- transp_Xt = np.dot(transp_, self.xs_)
+ transp_Xt = nx.dot(transp_, self.xs_)
else:
# perform out of sample mapping
- indices = np.arange(Xt.shape[0])
+ indices = nx.arange(Xt.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -1200,20 +1234,20 @@ class BaseTransport(BaseEstimator):
transp_Xt = []
for bi in batch_ind:
D0 = dist(Xt[bi], self.xt_)
- idx = np.argmin(D0, axis=1)
+ idx = nx.argmin(D0, axis=1)
# transport the target samples
- transp_ = self.coupling_.T / np.sum(
+ transp_ = self.coupling_.T / nx.sum(
self.coupling_, 0)[:, None]
- transp_[~ np.isfinite(transp_)] = 0
- transp_Xt_ = np.dot(transp_, self.xs_)
+ transp_[~ nx.isfinite(transp_)] = 0
+ transp_Xt_ = nx.dot(transp_, self.xs_)
# define the transported points
transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.xt_[idx, :]
transp_Xt.append(transp_Xt_)
- transp_Xt = np.concatenate(transp_Xt, axis=0)
+ transp_Xt = nx.concatenate(transp_Xt, axis=0)
return transp_Xt
@@ -1230,26 +1264,27 @@ class BaseTransport(BaseEstimator):
transp_ys : array-like, shape (n_source_samples, nb_classes)
Estimated soft source labels.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(yt=yt):
- ytTemp = label_normalization(np.copy(yt))
- classes = np.unique(ytTemp)
+ ytTemp = label_normalization(nx.copy(yt))
+ classes = nx.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(ytTemp)))
+ D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_)
# perform label propagation
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
for c in classes:
D1[int(c), ytTemp == c] = 1
# compute propagated samples
- transp_ys = np.dot(D1, transp.T)
+ transp_ys = nx.dot(D1, transp.T)
return transp_ys.T
@@ -1330,14 +1365,15 @@ class LinearTransport(BaseTransport):
self : object
Returns self.
"""
+ nx = self._get_backend(Xs, ys, Xt, yt)
self.mu_s = self.distribution_estimation(Xs)
self.mu_t = self.distribution_estimation(Xt)
# coupling estimation
returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg,
- ws=self.mu_s.reshape((-1, 1)),
- wt=self.mu_t.reshape((-1, 1)),
+ ws=nx.reshape(self.mu_s, (-1, 1)),
+ wt=nx.reshape(self.mu_t, (-1, 1)),
bias=self.bias, log=self.log)
# deal with the value of log
@@ -1348,8 +1384,8 @@ class LinearTransport(BaseTransport):
self.log_ = dict()
# re compute inverse mapping
- self.A1_ = linalg.inv(self.A_)
- self.B1_ = -self.B_.dot(self.A1_)
+ self.A1_ = nx.inv(self.A_)
+ self.B1_ = -nx.dot(self.B_, self.A1_)
return self
@@ -1378,10 +1414,11 @@ class LinearTransport(BaseTransport):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- transp_Xs = Xs.dot(self.A_) + self.B_
+ transp_Xs = nx.dot(Xs, self.A_) + self.B_
return transp_Xs
@@ -1411,10 +1448,11 @@ class LinearTransport(BaseTransport):
transp_Xt : array-like, shape (n_source_samples, n_features)
The transported target samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xt=Xt):
- transp_Xt = Xt.dot(self.A1_) + self.B1_
+ transp_Xt = nx.dot(Xt, self.A1_) + self.B1_
return transp_Xt
@@ -2112,6 +2150,7 @@ class MappingTransport(BaseEstimator):
self : object
Returns self
"""
+ self._get_backend(Xs, ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt):
@@ -2158,19 +2197,20 @@ class MappingTransport(BaseEstimator):
transp_Xs : array-like, shape (n_source_samples, n_features)
The transport source samples.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if np.array_equal(self.xs_, Xs):
+ if nx.array_equal(self.xs_, Xs):
# perform standard barycentric mapping
- transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+ transp = self.coupling_ / nx.sum(self.coupling_, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs = np.dot(transp, self.xt_)
+ transp_Xs = nx.dot(transp, self.xt_)
else:
if self.kernel == "gaussian":
K = kernel(Xs, self.xs_, method=self.kernel,
@@ -2178,8 +2218,10 @@ class MappingTransport(BaseEstimator):
elif self.kernel == "linear":
K = Xs
if self.bias:
- K = np.hstack((K, np.ones((Xs.shape[0], 1))))
- transp_Xs = K.dot(self.mapping_)
+ K = nx.concatenate(
+ [K, nx.ones((Xs.shape[0], 1), type_as=K)], axis=1
+ )
+ transp_Xs = nx.dot(K, self.mapping_)
return transp_Xs
@@ -2396,6 +2438,7 @@ class JCPOTTransport(BaseTransport):
self : object
Returns self.
"""
+ self._get_backend(*Xs, *ys, Xt, yt)
# check the necessary inputs parameters are here
if check_params(Xs=Xs, Xt=Xt, ys=ys):
@@ -2438,28 +2481,29 @@ class JCPOTTransport(BaseTransport):
batch_size : int, optional (default=128)
The batch size for out of sample inverse transform
"""
+ nx = self.nx
transp_Xs = []
# check the necessary inputs parameters are here
if check_params(Xs=Xs):
- if all([np.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
+ if all([nx.allclose(x, y) for x, y in zip(self.xs_, Xs)]):
# perform standard barycentric mapping for each source domain
for coupling in self.coupling_:
- transp = coupling / np.sum(coupling, 1)[:, None]
+ transp = coupling / nx.sum(coupling, 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute transported samples
- transp_Xs.append(np.dot(transp, self.xt_))
+ transp_Xs.append(nx.dot(transp, self.xt_))
else:
# perform out of sample mapping
- indices = np.arange(Xs.shape[0])
+ indices = nx.arange(Xs.shape[0])
batch_ind = [
indices[i:i + batch_size]
for i in range(0, len(indices), batch_size)]
@@ -2470,23 +2514,22 @@ class JCPOTTransport(BaseTransport):
transp_Xs_ = []
# get the nearest neighbor in the sources domains
- xs = np.concatenate(self.xs_, axis=0)
- idx = np.argmin(dist(Xs[bi], xs), axis=1)
+ xs = nx.concatenate(self.xs_, axis=0)
+ idx = nx.argmin(dist(Xs[bi], xs), axis=1)
# transport the source samples
for coupling in self.coupling_:
- transp = coupling / np.sum(
- coupling, 1)[:, None]
- transp[~ np.isfinite(transp)] = 0
- transp_Xs_.append(np.dot(transp, self.xt_))
+ transp = coupling / nx.sum(coupling, 1)[:, None]
+ transp[~ nx.isfinite(transp)] = 0
+ transp_Xs_.append(nx.dot(transp, self.xt_))
- transp_Xs_ = np.concatenate(transp_Xs_, axis=0)
+ transp_Xs_ = nx.concatenate(transp_Xs_, axis=0)
# define the transported points
transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - xs[idx, :]
transp_Xs.append(transp_Xs_)
- transp_Xs = np.concatenate(transp_Xs, axis=0)
+ transp_Xs = nx.concatenate(transp_Xs, axis=0)
return transp_Xs
@@ -2512,32 +2555,36 @@ class JCPOTTransport(BaseTransport):
"Optimal transport for multi-source domain adaptation under target shift",
International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(ys=ys):
- yt = np.zeros((len(np.unique(np.concatenate(ys))), self.xt_.shape[0]))
+ yt = nx.zeros(
+ (len(nx.unique(nx.concatenate(ys))), self.xt_.shape[0]),
+ type_as=ys[0]
+ )
for i in range(len(ys)):
- ysTemp = label_normalization(np.copy(ys[i]))
- classes = np.unique(ysTemp)
+ ysTemp = label_normalization(nx.copy(ys[i]))
+ classes = nx.unique(ysTemp)
n = len(classes)
ns = len(ysTemp)
# perform label propagation
- transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+ transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
if self.log:
D1 = self.log_['D1'][i]
else:
- D1 = np.zeros((n, ns))
+ D1 = nx.zeros((n, ns), type_as=transp)
for c in classes:
D1[int(c), ysTemp == c] = 1
# compute propagated labels
- yt = yt + np.dot(D1, transp) / len(ys)
+ yt = yt + nx.dot(D1, transp) / len(ys)
return yt.T
@@ -2555,14 +2602,15 @@ class JCPOTTransport(BaseTransport):
transp_ys : list of K array-like objects, shape K x (nk_source_samples, nb_classes)
A list of estimated soft source labels
"""
+ nx = self.nx
# check the necessary inputs parameters are here
if check_params(yt=yt):
transp_ys = []
- ytTemp = label_normalization(np.copy(yt))
- classes = np.unique(ytTemp)
+ ytTemp = label_normalization(nx.copy(yt))
+ classes = nx.unique(ytTemp)
n = len(classes)
- D1 = np.zeros((n, len(ytTemp)))
+ D1 = nx.zeros((n, len(ytTemp)), type_as=self.coupling_[0])
for c in classes:
D1[int(c), ytTemp == c] = 1
@@ -2570,12 +2618,12 @@ class JCPOTTransport(BaseTransport):
for i in range(len(self.xs_)):
# perform label propagation
- transp = self.coupling_[i] / np.sum(self.coupling_[i], 1)[:, None]
+ transp = self.coupling_[i] / nx.sum(self.coupling_[i], 1)[:, None]
# set nans to 0
- transp[~ np.isfinite(transp)] = 0
+ transp[~ nx.isfinite(transp)] = 0
# compute propagated labels
- transp_ys.append(np.dot(D1, transp.T).T)
+ transp_ys.append(nx.dot(D1, transp.T).T)
return transp_ys
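
With ot/da.py fully on the backend API (nx.sqrtm, nx.inv, nx.solve instead of scipy.linalg), the functional solvers also work directly on backend tensors. A minimal sketch (assumes torch is installed; OT_mapping_linear and its reg/bias arguments are the existing API):

    import torch
    import ot

    Xs = torch.randn(100, 3, dtype=torch.float64)
    Xt = 2.0 * torch.randn(120, 3, dtype=torch.float64) + 5.0

    A, b = ot.da.OT_mapping_linear(Xs, Xt, reg=1e-6, bias=True)  # affine map between Gaussians
    Xs_mapped = Xs @ A + b
    print(A.dtype, Xs_mapped.shape)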
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index d9b6fa9..abf7fe0 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -225,6 +225,13 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.
+ .. note:: This function will cast the computed transport plan to the data type
+ of the provided input with the following priority: :math:`\mathbf{a}`,
+ then :math:`\mathbf{b}`, then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
+
Uses the algorithm proposed in :ref:`[1] <references-emd>`.
Parameters
@@ -290,12 +297,16 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
# ensure float64
a = np.asarray(a, dtype=np.float64)
@@ -330,15 +341,23 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
u, v = estimate_dual_null_weights(u, v, a, b, M)
result_code_string = check_result(result_code)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
if log:
log = {}
log['cost'] = cost
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- return nx.from_numpy(G, type_as=M0), log
- return nx.from_numpy(G, type_as=M0)
+ return nx.from_numpy(G, type_as=type_as), log
+ return nx.from_numpy(G, type_as=type_as)
def emd2(a, b, M, processes=1,
@@ -364,6 +383,14 @@ def emd2(a, b, M, processes=1,
from all compatible backends. But the algorithm uses the C++ CPU backend
which can lead to copy overhead on GPU arrays.
+ .. note:: This function will cast the computed transport plan and
+ transportation loss to the data type of the provided input with the
+ following priority: :math:`\mathbf{a}`, then :math:`\mathbf{b}`,
+ then :math:`\mathbf{M}` if marginals are not provided.
+ Casting to an integer tensor might result in a loss of precision.
+ If this behaviour is unwanted, please make sure to provide a
+ floating point input.
+
Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
Parameters
@@ -432,12 +459,16 @@ def emd2(a, b, M, processes=1,
a, b, M = list_to_array(a, b, M)
a0, b0, M0 = a, b, M
+ if len(a0) != 0:
+ type_as = a0
+ elif len(b0) != 0:
+ type_as = b0
+ else:
+ type_as = M0
nx = get_backend(M0, a0, b0)
# convert to numpy
- M = nx.to_numpy(M)
- a = nx.to_numpy(a)
- b = nx.to_numpy(b)
+ M, a, b = nx.to_numpy(M, a, b)
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -470,14 +501,22 @@ def emd2(a, b, M, processes=1,
result_code_string = check_result(result_code)
log = {}
- G = nx.from_numpy(G, type_as=M0)
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
if return_matrix:
log['G'] = G
- log['u'] = nx.from_numpy(u, type_as=a0)
- log['v'] = nx.from_numpy(v, type_as=b0)
+ log['u'] = nx.from_numpy(u, type_as=type_as)
+ log['v'] = nx.from_numpy(v, type_as=type_as)
log['warning'] = result_code_string
log['result_code'] = result_code
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
(a0, b0, M0), (log['u'], log['v'], G))
return [cost, log]
else:
@@ -491,10 +530,18 @@ def emd2(a, b, M, processes=1,
if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)
- G = nx.from_numpy(G, type_as=M0)
- cost = nx.set_gradients(nx.from_numpy(cost, type_as=M0),
- (a0, b0, M0), (nx.from_numpy(u, type_as=a0),
- nx.from_numpy(v, type_as=b0), G))
+ if not nx.is_floating_point(type_as):
+ warnings.warn(
+ "Input histogram consists of integer. The transport plan will be "
+ "casted accordingly, possibly resulting in a loss of precision. "
+ "If this behaviour is unwanted, please make sure your input "
+ "histogram consists of floating point elements.",
+ stacklevel=2
+ )
+ G = nx.from_numpy(G, type_as=type_as)
+ cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
+ (a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
+ nx.from_numpy(v, type_as=type_as), G))
check_result(result_code)
return cost
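
The casting rule documented in the notes above (the plan's dtype follows a, then b, then M) and the new integer warning can be seen with a toy call; the values below are made up for illustration:

    import numpy as np
    import ot

    a = np.array([1, 1, 1])          # integer histogram, triggers the warning
    b = np.array([1, 1, 1])
    M = np.random.rand(3, 3)         # floating point cost matrix

    G = ot.emd(a, b, M)              # UserWarning; plan cast to the integer dtype of a
    print(G.dtype)

    G_float = ot.emd(a.astype(np.float64), b.astype(np.float64), M)
    print(G_float.dtype)             # float64, no warning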
diff --git a/ot/optim.py b/ot/optim.py
index f25e2c9..5a1d605 100644
--- a/ot/optim.py
+++ b/ot/optim.py
@@ -9,12 +9,19 @@ Generic solvers for regularized OT
# License: MIT License
import numpy as np
-from scipy.optimize.linesearch import scalar_search_armijo
+import warnings
from .lp import emd
from .bregman import sinkhorn
-from ot.utils import list_to_array
+from .utils import list_to_array
from .backend import get_backend
+with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ from scipy.optimize import scalar_search_armijo
+ except ImportError:
+ from scipy.optimize.linesearch import scalar_search_armijo
+
# The corresponding scipy function does not work for matrices
diff --git a/ot/unbalanced.py b/ot/unbalanced.py
index 15e180b..503cc1e 100644
--- a/ot/unbalanced.py
+++ b/ot/unbalanced.py
@@ -8,9 +8,9 @@ Regularized Unbalanced OT solvers
from __future__ import division
import warnings
-import numpy as np
-from scipy.special import logsumexp
+from .backend import get_backend
+from .utils import list_to_array
# from .utils import unif, dist
@@ -43,12 +43,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -70,12 +70,12 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -172,12 +172,12 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -198,7 +198,7 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
Returns
-------
- ot_distance : (n_hists,) ndarray
+ ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
log : dict
log dictionary returned only if `log` is `True`
@@ -239,9 +239,10 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
"""
- b = np.asarray(b, dtype=np.float64)
+ b = list_to_array(b)
if len(b.shape) < 2:
b = b[:, None]
+
if method.lower() == 'sinkhorn':
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
numItermax=numItermax,
@@ -291,12 +292,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`
If many, compute all the OT distances (a, b_i)
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -315,12 +316,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -354,17 +355,15 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ M, a, b = list_to_array(M, a, b)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -377,17 +376,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, 1)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, 1), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(M / (-reg))
fi = reg_m / (reg_m + reg)
@@ -397,14 +393,14 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
uprev = u
vprev = v
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (a / Kv) ** fi
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = (b / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -412,8 +408,12 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
v = vprev
break
- err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.)
- err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.)
+ err_u = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
+ err_v = nx.max(nx.abs(v - vprev)) / max(
+ nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
+ )
err = 0.5 * (err_u + err_v)
if log:
log['err'].append(err)
@@ -426,11 +426,11 @@ def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
break
if log:
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
if n_hists: # return only loss
- res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
+ res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
if log:
return res, log
else:
@@ -475,12 +475,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Parameters
----------
- a : np.ndarray (dim_a,)
+ a : array-like (dim_a,)
Unnormalized histogram of dimension `dim_a`
- b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
+ b : array-like (dim_b,) or array-like (dim_b, n_hists)
One or multiple unnormalized histograms of dimension `dim_b`.
If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
- M : np.ndarray (dim_a, dim_b)
+ M : array-like (dim_a, dim_b)
loss matrix
reg : float
Entropy regularization term > 0
@@ -501,12 +501,12 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
Returns
-------
if n_hists == 1:
- - gamma : (dim_a, dim_b) ndarray
+ - gamma : (dim_a, dim_b) array-like
Optimal transportation matrix for the given parameters
- log : dict
log dictionary returned only if `log` is `True`
else:
- - ot_distance : (n_hists,) ndarray
+ - ot_distance : (n_hists,) array-like
the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
- log : dict
log dictionary returned only if `log` is `True`
@@ -538,17 +538,15 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
ot.optim.cg : General regularized OT
"""
-
- a = np.asarray(a, dtype=np.float64)
- b = np.asarray(b, dtype=np.float64)
- M = np.asarray(M, dtype=np.float64)
+ a, b, M = list_to_array(a, b, M)
+ nx = get_backend(M, a, b)
dim_a, dim_b = M.shape
if len(a) == 0:
- a = np.ones(dim_a, dtype=np.float64) / dim_a
+ a = nx.ones(dim_a, type_as=M) / dim_a
if len(b) == 0:
- b = np.ones(dim_b, dtype=np.float64) / dim_b
+ b = nx.ones(dim_b, type_as=M) / dim_b
if len(b.shape) > 1:
n_hists = b.shape[1]
@@ -561,56 +559,52 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
# we assume that no distances are null except those of the diagonal of
# distances
if n_hists:
- u = np.ones((dim_a, n_hists)) / dim_a
- v = np.ones((dim_b, n_hists)) / dim_b
+ u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
+ v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
a = a.reshape(dim_a, 1)
else:
- u = np.ones(dim_a) / dim_a
- v = np.ones(dim_b) / dim_b
+ u = nx.ones(dim_a, type_as=M) / dim_a
+ v = nx.ones(dim_b, type_as=M) / dim_b
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim_a)
- beta = np.zeros(dim_b)
+ alpha = nx.zeros(dim_a, type_as=M)
+ beta = nx.zeros(dim_b, type_as=M)
while (err > stopThr and cpt < numItermax):
uprev = u
vprev = v
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
if n_hists:
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((a / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
v = ((b / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
if n_hists:
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
else:
- alpha = alpha + reg * np.log(np.max(u))
- beta = beta + reg * np.log(np.max(v))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
-
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u))
+ beta = beta + reg * nx.log(nx.max(v))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
@@ -620,8 +614,9 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
if (cpt % 10 == 0 and not absorbing) or cpt == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(),
- 1.)
+ err = nx.max(nx.abs(u - uprev)) / max(
+ nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -636,25 +631,30 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000
"Try a larger entropy `reg` or a lower mass `reg_m`." +
"Or a larger absorption threshold `tau`.")
if n_hists:
- logu = alpha[:, None] / reg + np.log(u)
- logv = beta[:, None] / reg + np.log(v)
+ logu = alpha[:, None] / reg + nx.log(u)
+ logv = beta[:, None] / reg + nx.log(v)
else:
- logu = alpha / reg + np.log(u)
- logv = beta / reg + np.log(v)
+ logu = alpha / reg + nx.log(u)
+ logv = beta / reg + nx.log(v)
if log:
log['logu'] = logu
log['logv'] = logv
if n_hists: # return only loss
- res = logsumexp(np.log(M + 1e-100)[:, :, None] + logu[:, None, :] +
- logv[None, :, :] - M[:, :, None] / reg, axis=(0, 1))
- res = np.exp(res)
+ res = nx.logsumexp(
+ nx.log(M + 1e-100)[:, :, None]
+ + logu[:, None, :]
+ + logv[None, :, :]
+ - M[:, :, None] / reg,
+ axis=(0, 1)
+ )
+ res = nx.exp(res)
if log:
return res, log
else:
return res
else: # return OT matrix
- ot_matrix = np.exp(logu[:, None] + logv[None, :] - M / reg)
+ ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M / reg)
if log:
return ot_matrix, log
else:
@@ -683,9 +683,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
@@ -693,7 +693,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Marginal relaxation term > 0
tau : float
Stabilization threshold for log domain absorption.
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -708,7 +708,7 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
@@ -726,9 +726,12 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
@@ -737,47 +740,43 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
fi = reg_m / (reg_m + reg)
- u = np.ones((dim, n_hists)) / dim
- v = np.ones((dim, n_hists)) / dim
+ u = nx.ones((dim, n_hists), type_as=A) / dim
+ v = nx.ones((dim, n_hists), type_as=A) / dim
# print(reg)
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
- K = np.empty(M.shape, dtype=M.dtype)
- np.divide(M, -reg, out=K)
- np.exp(K, out=K)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
cpt = 0
err = 1.
- alpha = np.zeros(dim)
- beta = np.zeros(dim)
- q = np.ones(dim) / dim
+ alpha = nx.zeros(dim, type_as=A)
+ beta = nx.zeros(dim, type_as=A)
+ q = nx.ones(dim, type_as=A) / dim
for i in range(numItermax):
- qprev = q.copy()
- Kv = K.dot(v)
- f_alpha = np.exp(- alpha / (reg + reg_m))
- f_beta = np.exp(- beta / (reg + reg_m))
+ qprev = nx.copy(q)
+ Kv = nx.dot(K, v)
+ f_alpha = nx.exp(- alpha / (reg + reg_m))
+ f_beta = nx.exp(- beta / (reg + reg_m))
f_alpha = f_alpha[:, None]
f_beta = f_beta[:, None]
u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
- Ktu = K.T.dot(u)
+ Ktu = nx.dot(K.T, u)
q = (Ktu ** (1 - fi)) * f_beta
- q = q.dot(weights) ** (1 / (1 - fi))
+ q = nx.dot(q, weights) ** (1 / (1 - fi))
Q = q[:, None]
v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
absorbing = False
- if (u > tau).any() or (v > tau).any():
+ if nx.any(u > tau) or nx.any(v > tau):
absorbing = True
- alpha = alpha + reg * np.log(np.max(u, 1))
- beta = beta + reg * np.log(np.max(v, 1))
- K = np.exp((alpha[:, None] + beta[None, :] -
- M) / reg)
- v = np.ones_like(v)
- Kv = K.dot(v)
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ alpha = alpha + reg * nx.log(nx.max(u, 1))
+ beta = beta + reg * nx.log(nx.max(v, 1))
+ K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
+ v = nx.ones(v.shape, type_as=v)
+ Kv = nx.dot(K, v)
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % cpt)
@@ -786,8 +785,9 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
if (i % 10 == 0 and not absorbing) or i == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
- err = abs(q - qprev).max() / max(abs(q).max(),
- abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.
+ )
if log:
log['err'].append(err)
if verbose:
@@ -804,8 +804,8 @@ def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
"Or a larger absorption threshold `tau`.")
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -833,15 +833,15 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -856,7 +856,7 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
@@ -874,40 +874,43 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
"""
+ A, M = list_to_array(A, M)
+ nx = get_backend(A, M)
+
dim, n_hists = A.shape
if weights is None:
- weights = np.ones(n_hists) / n_hists
+ weights = nx.ones(n_hists, type_as=A) / n_hists
else:
assert(len(weights) == A.shape[1])
if log:
log = {'err': []}
- K = np.exp(- M / reg)
+ K = nx.exp(-M / reg)
fi = reg_m / (reg_m + reg)
- v = np.ones((dim, n_hists))
- u = np.ones((dim, 1))
- q = np.ones(dim)
+ v = nx.ones((dim, n_hists), type_as=A)
+ u = nx.ones((dim, 1), type_as=A)
+ q = nx.ones(dim, type_as=A)
err = 1.
for i in range(numItermax):
- uprev = u.copy()
- vprev = v.copy()
- qprev = q.copy()
+ uprev = nx.copy(u)
+ vprev = nx.copy(v)
+ qprev = nx.copy(q)
- Kv = K.dot(v)
+ Kv = nx.dot(K, v)
u = (A / Kv) ** fi
- Ktu = K.T.dot(u)
- q = ((Ktu ** (1 - fi)).dot(weights))
+ Ktu = nx.dot(K.T, u)
+ q = nx.dot(Ktu ** (1 - fi), weights)
q = q ** (1 / (1 - fi))
Q = q[:, None]
v = (Q / Ktu) ** fi
- if (np.any(Ktu == 0.)
- or np.any(np.isnan(u)) or np.any(np.isnan(v))
- or np.any(np.isinf(u)) or np.any(np.isinf(v))):
+ if (nx.any(Ktu == 0.)
+ or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
+ or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %s' % i)
@@ -916,8 +919,9 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
q = qprev
break
# compute change in barycenter
- err = abs(q - qprev).max()
- err /= max(abs(q).max(), abs(qprev).max(), 1.)
+ err = nx.max(nx.abs(q - qprev)) / max(
+ nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.0
+ )
if log:
log['err'].append(err)
# if barycenter did not change + at least 10 iterations - stop
@@ -932,8 +936,8 @@ def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
if log:
log['niter'] = i
- log['logu'] = np.log(u + 1e-300)
- log['logv'] = np.log(v + 1e-300)
+ log['logu'] = nx.log(u + 1e-300)
+ log['logv'] = nx.log(v + 1e-300)
return q, log
else:
return q
@@ -961,15 +965,15 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Parameters
----------
- A : np.ndarray (dim, n_hists)
+ A : array-like (dim, n_hists)
`n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
- M : np.ndarray (dim, dim)
+ M : array-like (dim, dim)
ground metric matrix for OT.
reg : float
Entropy regularization term > 0
reg_m: float
Marginal relaxation term > 0
- weights : np.ndarray (n_hists,) optional
+ weights : array-like (n_hists,) optional
Weight of each distribution (barycentric coordinates)
If None, uniform weights are used.
numItermax : int, optional
@@ -984,7 +988,7 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
Returns
-------
- a : (dim,) ndarray
+ a : (dim,) array-like
Unbalanced Wasserstein barycenter
log : dict
log dictionary returned only if log==True in parameters
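
Editorial sketch: the barycenter routines documented above are reached through `ot.unbalanced.barycenter_unbalanced`, which dispatches on `method`. A small NumPy example (values illustrative; after this PR the same call also accepts e.g. torch tensors):

    import numpy as np
    import ot

    A = np.random.rand(10, 3)                      # 3 histograms of dimension 10
    A /= A.sum(axis=0, keepdims=True)
    M = ot.utils.dist0(10)                         # ground metric on a 1D grid
    M /= M.max()
    bary = ot.unbalanced.barycenter_unbalanced(
        A, M, reg=0.05, reg_m=1.0, method="sinkhorn_stabilized")
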
diff --git a/ot/utils.py b/ot/utils.py
index 725ca00..a23ce7e 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -15,7 +15,7 @@ from scipy.spatial.distance import cdist
import sys
import warnings
from inspect import signature
-from .backend import get_backend
+from .backend import get_backend, Backend
__time_tic_toc = time.time()
@@ -51,7 +51,8 @@ def kernel(x1, x2, method='gaussian', sigma=1, **kwargs):
def laplacian(x):
r"""Compute Laplacian matrix"""
- L = np.diag(np.sum(x, axis=0)) - x
+ nx = get_backend(x)
+ L = nx.diag(nx.sum(x, axis=0)) - x
return L
@@ -136,7 +137,7 @@ def unif(n, type_as=None):
return np.ones((n,)) / n
else:
nx = get_backend(type_as)
- return nx.ones((n,)) / n
+ return nx.ones((n,), type_as=type_as) / n
def clean_zeros(a, b, M):
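
Editorial sketch: the one-line fix above matters because `type_as` was previously accepted but ignored when allocating the uniform weights. A quick example of the intended behaviour (assuming PyTorch is installed):

    import torch
    from ot.utils import unif

    w = unif(4)                                           # NumPy float64, unchanged behaviour
    wt = unif(4, type_as=torch.zeros(1, dtype=torch.float32))
    # wt is now a torch float32 tensor; before this fix the dtype/device of
    # type_as was dropped and the backend default was used instead
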
@@ -296,7 +297,8 @@ def cost_normalization(C, norm=None):
def dots(*args):
r""" dots function for multiple matrix multiply """
- return reduce(np.dot, args)
+ nx = get_backend(*args)
+ return reduce(nx.dot, args)
def label_normalization(y, start=0):
@@ -314,8 +316,9 @@ def label_normalization(y, start=0):
y : array-like, shape (`n1`, )
The input vector of labels normalized according to given start value.
"""
+ nx = get_backend(y)
- diff = np.min(np.unique(y)) - start
+ diff = nx.min(nx.unique(y)) - start
if diff != 0:
y -= diff
return y
@@ -482,6 +485,19 @@ class BaseEstimator(object):
arguments (no ``*args`` or ``**kwargs``).
"""
+ nx: Backend = None
+
+ def _get_backend(self, *arrays):
+ nx = get_backend(
+ *[input_ for input_ in arrays if input_ is not None]
+ )
+ if nx.__name__ in ("jax", "tf"):
+ raise TypeError(
+ """JAX or TF arrays have been received but domain
+                adaptation does not support those backends.""")
+ self.nx = nx
+ return nx
+
@classmethod
def _get_param_names(cls):
r"""Get parameter names for the estimator"""
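
Editorial sketch: the new `_get_backend` helper is what the refactored `ot.da` estimators call at the start of `fit` to pick up the caller's backend and to reject JAX/TensorFlow inputs. A hypothetical subclass showing the intended call pattern (not actual POT code):

    import ot

    class MyTransport(ot.da.BaseTransport):
        def fit(self, Xs=None, ys=None, Xt=None, yt=None):
            nx = self._get_backend(Xs, ys, Xt, yt)     # ignores None, raises TypeError for jax/tf
            self.mu_s = ot.unif(Xs.shape[0], type_as=Xs)
            self.mu_t = ot.unif(Xt.shape[0], type_as=Xt)
            self.cost_ = ot.dist(Xs, Xt)               # squared Euclidean cost, backend-aware
            return self
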
diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py
index 6a42cfe..20f307a 100644
--- a/test/test_1d_solver.py
+++ b/test/test_1d_solver.py
@@ -66,9 +66,7 @@ def test_wasserstein_1d(nx):
rho_v = np.abs(rng.randn(n))
rho_v /= rho_v.sum()
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
# test 1 : wasserstein_1d should be close to scipy W_1 implementation
np.testing.assert_almost_equal(wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1),
@@ -98,9 +96,7 @@ def test_wasserstein_1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
@@ -122,17 +118,13 @@ def test_wasserstein_1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
res = wasserstein_1d(xb, xb, rho_ub, rho_vb, p=1)
nx.assert_same_dtype_device(xb, res)
assert nx.dtype_device(res)[1].startswith("GPU")
@@ -190,9 +182,7 @@ def test_emd1d_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- rho_ub = nx.from_numpy(rho_u, type_as=tp)
- rho_vb = nx.from_numpy(rho_v, type_as=tp)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
@@ -214,9 +204,7 @@ def test_emd1d_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
@@ -224,9 +212,7 @@ def test_emd1d_device_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- rho_ub = nx.from_numpy(rho_u)
- rho_vb = nx.from_numpy(rho_v)
+ xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
emd = ot.emd_1d(xb, xb, rho_ub, rho_vb)
emd2 = ot.emd2_1d(xb, xb, rho_ub, rho_vb)
nx.assert_same_dtype_device(xb, emd)
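
Editorial sketch: these test updates lean on `from_numpy` now accepting several arrays in one call and returning them as a tuple; with a single input it still returns a single array. A short example with the NumPy backend:

    import numpy as np
    import ot

    nx = ot.backend.NumpyBackend()          # any backend object behaves the same way
    x = np.arange(5.0)
    a, b = ot.unif(5), ot.unif(5)
    xb, ab, bb = nx.from_numpy(x, a, b)     # one call instead of three
    xb_only = nx.from_numpy(x)              # single input -> single array, as before
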
diff --git a/test/test_backend.py b/test/test_backend.py
index 027c4cd..311c075 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -218,6 +218,8 @@ def test_empty_backend():
with pytest.raises(NotImplementedError):
nx.argmax(M)
with pytest.raises(NotImplementedError):
+ nx.argmin(M)
+ with pytest.raises(NotImplementedError):
nx.mean(M)
with pytest.raises(NotImplementedError):
nx.std(M)
@@ -264,12 +266,27 @@ def test_empty_backend():
nx.device_type(M)
with pytest.raises(NotImplementedError):
nx._bench(lambda x: x, M, n_runs=1)
+ with pytest.raises(NotImplementedError):
+ nx.solve(M, v)
+ with pytest.raises(NotImplementedError):
+ nx.trace(M)
+ with pytest.raises(NotImplementedError):
+ nx.inv(M)
+ with pytest.raises(NotImplementedError):
+ nx.sqrtm(M)
+ with pytest.raises(NotImplementedError):
+ nx.isfinite(M)
+ with pytest.raises(NotImplementedError):
+ nx.array_equal(M, M)
+ with pytest.raises(NotImplementedError):
+ nx.is_floating_point(M)
def test_func_backends(nx):
rnd = np.random.RandomState(0)
M = rnd.randn(10, 3)
+ SquareM = rnd.randn(10, 10)
v = rnd.randn(3)
val = np.array([1.0])
@@ -288,6 +305,7 @@ def test_func_backends(nx):
lst_name = []
Mb = nx.from_numpy(M)
+ SquareMb = nx.from_numpy(SquareM)
vb = nx.from_numpy(v)
val = nx.from_numpy(val)
@@ -467,6 +485,10 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('argmax')
+ A = nx.argmin(Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('argmin')
+
A = nx.mean(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('mean')
@@ -529,7 +551,11 @@ def test_func_backends(nx):
A = nx.where(Mb >= nx.stack([nx.linspace(0, 1, 10)] * 3, axis=1), Mb, 0.0)
lst_b.append(nx.to_numpy(A))
- lst_name.append('where')
+ lst_name.append('where (cond, x, y)')
+
+ A = nx.where(nx.from_numpy(np.array([True, False])))
+ lst_b.append(nx.to_numpy(nx.stack(A)))
+ lst_name.append('where (cond)')
A = nx.copy(Mb)
lst_b.append(nx.to_numpy(A))
@@ -550,15 +576,47 @@ def test_func_backends(nx):
nx._bench(lambda x: x, M, n_runs=1)
+ A = nx.solve(SquareMb, Mb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('solve')
+
+ A = nx.trace(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('trace')
+
+ A = nx.inv(SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append('matrix inverse')
+
+ A = nx.sqrtm(SquareMb.T @ SquareMb)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("matrix square root")
+
+ A = nx.concatenate([vb, nx.from_numpy(np.array([np.inf, np.nan]))], axis=0)
+ A = nx.isfinite(A)
+ lst_b.append(nx.to_numpy(A))
+ lst_name.append("isfinite")
+
+ assert not nx.array_equal(Mb, vb), "array_equal (shape)"
+ assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
+ assert not nx.array_equal(
+ Mb, Mb + nx.eye(*list(Mb.shape))
+ ), "array_equal (elements) - expected false"
+
+ assert nx.is_floating_point(Mb), "is_floating_point - expected true"
+ assert not nx.is_floating_point(
+ nx.from_numpy(np.array([0, 1, 2], dtype=int))
+ ), "is_floating_point - expected false"
+
lst_tot.append(lst_b)
lst_np = lst_tot[0]
lst_b = lst_tot[1]
for a1, a2, name in zip(lst_np, lst_b, lst_name):
- if not np.allclose(a1, a2):
- print('Assert fail on: ', name)
- assert np.allclose(a1, a2, atol=1e-7)
+ np.testing.assert_allclose(
+ a2, a1, atol=1e-7, err_msg=f'ASSERT FAILED ON: {name}'
+ )
def test_random_backends(nx):
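
Editorial sketch: the block above exercises the backend primitives added alongside this PR (`solve`, `trace`, `inv`, `sqrtm`, `isfinite`, `array_equal`, `is_floating_point`, plus `argmin`). A short NumPy-backend example of how they behave:

    import numpy as np
    import ot

    nx = ot.backend.NumpyBackend()
    A = np.random.rand(4, 4)
    S = A @ A.T + np.eye(4)                  # symmetric positive definite matrix

    nx.trace(S)                              # scalar trace
    nx.inv(S)                                # matrix inverse
    nx.solve(S, np.ones(4))                  # solve S x = 1
    R = nx.sqrtm(S)                          # principal matrix square root, R @ R ~= S
    assert nx.array_equal(S, S)
    assert nx.is_floating_point(S)
    assert not nx.is_floating_point(np.arange(3))
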
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 1419f9b..6c37984 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -155,8 +155,7 @@ def test_sinkhorn_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn(ab, ab, M_nx, 1)
@@ -176,8 +175,7 @@ def test_sinkhorn2_backends(nx):
G = ot.sinkhorn(a, a, M, 1)
- ab = nx.from_numpy(a)
- M_nx = nx.from_numpy(M)
+ ab, M_nx = nx.from_numpy(a, M)
Gb = ot.sinkhorn2(ab, ab, M_nx, 1)
@@ -260,8 +258,7 @@ def test_sinkhorn_variants(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- M_nx = nx.from_numpy(M)
+ ub, M_nx = nx.from_numpy(u, M)
G = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, ub, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -298,8 +295,7 @@ def test_sinkhorn_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -318,8 +314,7 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ub = nx.from_numpy(u, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ub, Mb = nx.from_numpy(u, M, type_as=tp)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
@@ -337,8 +332,7 @@ def test_sinkhorn2_variants_device_tf(method):
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -346,8 +340,7 @@ def test_sinkhorn2_variants_device_tf(method):
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ub = nx.from_numpy(u)
- Mb = nx.from_numpy(M)
+ ub, Mb = nx.from_numpy(u, M)
Gb = ot.sinkhorn(ub, ub, Mb, 1, method=method, stopThr=1e-10)
lossb = ot.sinkhorn2(ub, ub, Mb, 1, method=method, stopThr=1e-10)
nx.assert_same_dtype_device(Mb, Gb)
@@ -370,9 +363,7 @@ def test_sinkhorn_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -400,9 +391,7 @@ def test_sinkhorn2_variants_multi_b(nx):
M = ot.dist(x, x)
- ub = nx.from_numpy(u)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M)
+ ub, bb, M_nx = nx.from_numpy(u, b, M)
G = ot.sinkhorn2(u, b, M, 1, method='sinkhorn', stopThr=1e-10)
Gl = nx.to_numpy(ot.sinkhorn2(ub, bb, M_nx, 1, method='sinkhorn_log', stopThr=1e-10))
@@ -483,9 +472,7 @@ def test_barycenter(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
@@ -523,9 +510,7 @@ def test_barycenter_debiased(nx, method, verbose, warn):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_nx = nx.from_numpy(weights)
+ A_nx, M_nx, weights_nx = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -594,9 +579,7 @@ def test_barycenter_stabilization(nx):
alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])
- A_nx = nx.from_numpy(A)
- M_nx = nx.from_numpy(M)
- weights_b = nx.from_numpy(weights)
+ A_nx, M_nx, weights_b = nx.from_numpy(A, M, weights)
# wasserstein
reg = 1e-2
@@ -697,11 +680,7 @@ def test_unmix(nx):
M0 /= M0.max()
h0 = ot.unif(2)
- ab = nx.from_numpy(a)
- Db = nx.from_numpy(D)
- M_nx = nx.from_numpy(M)
- M0b = nx.from_numpy(M0)
- h0b = nx.from_numpy(h0)
+ ab, Db, M_nx, M0b, h0b = nx.from_numpy(a, D, M, M0, h0)
# wasserstein
reg = 1e-3
@@ -727,12 +706,7 @@ def test_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1))
sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1))
@@ -776,12 +750,7 @@ def test_lazy_empirical_sinkhorn(nx):
M = ot.dist(X_s, X_t)
M_m = ot.dist(X_s, X_t, metric='euclidean')
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_mb = nx.from_numpy(M_m, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_mb = nx.from_numpy(a, b, X_s, X_t, M, M_m)
f, g = ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 3), verbose=True)
f, g = nx.to_numpy(f), nx.to_numpy(g)
@@ -825,19 +794,13 @@ def test_empirical_sinkhorn_divergence(nx):
a = np.linspace(1, n, n)
a /= a.sum()
b = ot.unif(n)
- X_s = np.reshape(np.arange(n), (n, 1))
- X_t = np.reshape(np.arange(0, n * 2, 2), (n, 1))
+ X_s = np.reshape(np.arange(n, dtype=np.float64), (n, 1))
+ X_t = np.reshape(np.arange(0, n * 2, 2, dtype=np.float64), (n, 1))
M = ot.dist(X_s, X_t)
M_s = ot.dist(X_s, X_s)
M_t = ot.dist(X_t, X_t)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- X_sb = nx.from_numpy(X_s)
- X_tb = nx.from_numpy(X_t)
- M_nx = nx.from_numpy(M, type_as=ab)
- M_sb = nx.from_numpy(M_s, type_as=ab)
- M_tb = nx.from_numpy(M_t, type_as=ab)
+ ab, bb, X_sb, X_tb, M_nx, M_sb, M_tb = nx.from_numpy(a, b, X_s, X_t, M, M_s, M_t)
emp_sinkhorn_div = nx.to_numpy(ot.bregman.empirical_sinkhorn_divergence(X_sb, X_tb, 1, a=ab, b=bb))
sinkhorn_div = nx.to_numpy(
@@ -872,9 +835,7 @@ def test_stabilized_vs_sinkhorn_multidim(nx):
M /= np.median(M)
epsilon = 0.1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
G_np, _ = ot.bregman.sinkhorn(a, b, M, reg=epsilon, method="sinkhorn", log=True)
G, log = ot.bregman.sinkhorn(ab, bb, M_nx, reg=epsilon,
@@ -936,9 +897,7 @@ def test_screenkhorn(nx):
x = rng.randn(n, 2)
M = ot.dist(x, x)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- M_nx = nx.from_numpy(M, type_as=ab)
+ ab, bb, M_nx = nx.from_numpy(a, b, M)
# sinkhorn
G_sink = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1e-1))
diff --git a/test/test_da.py b/test/test_da.py
index 9f2bb50..4bf0ab1 100644
--- a/test/test_da.py
+++ b/test/test_da.py
@@ -19,7 +19,32 @@ except ImportError:
nosklearn = True
-def test_sinkhorn_lpl1_transport_class():
+def test_class_jax_tf():
+ backends = []
+ from ot.backend import jax, tf
+ if jax:
+ backends.append(ot.backend.JaxBackend())
+ if tf:
+ backends.append(ot.backend.TensorflowBackend())
+
+ for nx in backends:
+ ns = 150
+ nt = 200
+
+ Xs, ys = make_data_classif('3gauss', ns)
+ Xt, yt = make_data_classif('3gauss2', nt)
+
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
+ otda = ot.da.SinkhornLpl1Transport()
+
+ with pytest.raises(TypeError):
+ otda.fit(Xs=Xs, ys=ys, Xt=Xt)
+
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -29,6 +54,8 @@ def test_sinkhorn_lpl1_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornLpl1Transport()
# test its computed
@@ -44,15 +71,15 @@ def test_sinkhorn_lpl1_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -62,7 +89,7 @@ def test_sinkhorn_lpl1_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -85,24 +112,26 @@ def test_sinkhorn_lpl1_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornLpl1Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornLpl1Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
assert mass_semi == 0, "semisupervised mode not working"
-def test_sinkhorn_l1l2_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_l1l2_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -112,6 +141,8 @@ def test_sinkhorn_l1l2_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornL1l2Transport()
# test its computed
@@ -128,15 +159,15 @@ def test_sinkhorn_l1l2_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -156,7 +187,7 @@ def test_sinkhorn_l1l2_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -169,22 +200,22 @@ def test_sinkhorn_l1l2_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornL1l2Transport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornL1l2Transport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
- assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-9, atol=1e-9)
# check everything runs well with log=True
@@ -193,7 +224,9 @@ def test_sinkhorn_l1l2_transport_class():
assert len(otda.log_.keys()) != 0
-def test_sinkhorn_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -203,6 +236,8 @@ def test_sinkhorn_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.SinkhornTransport()
# test its computed
@@ -219,15 +254,15 @@ def test_sinkhorn_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -247,7 +282,7 @@ def test_sinkhorn_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -260,19 +295,19 @@ def test_sinkhorn_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
assert mass_semi == 0, "semisupervised mode not working"
@@ -282,7 +317,9 @@ def test_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
-def test_unbalanced_sinkhorn_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -292,6 +329,8 @@ def test_unbalanced_sinkhorn_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.UnbalancedSinkhornTransport()
# test its computed
@@ -318,7 +357,7 @@ def test_unbalanced_sinkhorn_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -328,7 +367,7 @@ def test_unbalanced_sinkhorn_transport_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -341,12 +380,12 @@ def test_unbalanced_sinkhorn_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.SinkhornTransport()
otda_unsup.fit(Xs=Xs, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.SinkhornTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
@@ -357,7 +396,9 @@ def test_unbalanced_sinkhorn_transport_class():
assert len(otda.log_.keys()) != 0
-def test_emd_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""
@@ -367,6 +408,8 @@ def test_emd_transport_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.EMDTransport()
# test its computed
@@ -382,15 +425,15 @@ def test_emd_transport_class():
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
assert_equal(transp_Xs.shape, Xs.shape)
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -410,7 +453,7 @@ def test_emd_transport_class():
assert_equal(transp_ys.shape[0], ys.shape[0])
assert_equal(transp_ys.shape[1], len(np.unique(yt)))
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -423,28 +466,32 @@ def test_emd_transport_class():
# test unsupervised vs semi-supervised mode
otda_unsup = ot.da.EMDTransport()
otda_unsup.fit(Xs=Xs, ys=ys, Xt=Xt)
- n_unsup = np.sum(otda_unsup.cost_)
+ n_unsup = nx.sum(otda_unsup.cost_)
otda_semi = ot.da.EMDTransport()
otda_semi.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
assert_equal(otda_semi.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
- n_semisup = np.sum(otda_semi.cost_)
+ n_semisup = nx.sum(otda_semi.cost_)
# check that the cost matrix norms are indeed different
assert n_unsup != n_semisup, "semisupervised mode not working"
# check that the coupling forbids mass transport between labeled source
# and labeled target samples
- mass_semi = np.sum(
+ mass_semi = nx.sum(
otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max])
mass_semi = otda_semi.coupling_[otda_semi.cost_ == otda_semi.limit_max]
# we need to use a small tolerance here, otherwise the test breaks
- assert_allclose(mass_semi, np.zeros_like(mass_semi),
+ assert_allclose(nx.to_numpy(mass_semi), np.zeros(list(mass_semi.shape)),
rtol=1e-2, atol=1e-2)
-def test_mapping_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+@pytest.mark.parametrize("kernel", ["linear", "gaussian"])
+@pytest.mark.parametrize("bias", ["unbiased", "biased"])
+def test_mapping_transport_class(nx, kernel, bias):
"""test_mapping_transport
"""
@@ -455,101 +502,29 @@ def test_mapping_transport_class():
Xt, yt = make_data_classif('3gauss2', nt)
Xs_new, _ = make_data_classif('3gauss', ns + 1)
- ##########################################################################
- # kernel == linear mapping tests
- ##########################################################################
+ Xs, Xt, Xs_new = nx.from_numpy(Xs, Xt, Xs_new)
- # check computation and dimensions if bias == False
- otda = ot.da.MappingTransport(kernel="linear", bias=False)
+ # Mapping tests
+ bias = bias == "biased"
+ otda = ot.da.MappingTransport(kernel=kernel, bias=bias)
otda.fit(Xs=Xs, Xt=Xt)
assert hasattr(otda, "coupling_")
assert hasattr(otda, "mapping_")
assert hasattr(otda, "log_")
assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[1], Xt.shape[1])))
+ S = Xs.shape[0] if kernel == "gaussian" else Xs.shape[1] # if linear
+ if bias:
+ S += 1
+ assert_equal(otda.mapping_.shape, ((S, Xt.shape[1])))
# test margin constraints
mu_s = unif(ns)
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- # check computation and dimensions if bias == True
- otda = ot.da.MappingTransport(kernel="linear", bias=True)
- otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[1] + 1, Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- ##########################################################################
- # kernel == gaussian mapping tests
- ##########################################################################
-
- # check computation and dimensions if bias == False
- otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
- otda.fit(Xs=Xs, Xt=Xt)
-
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[0], Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
-
- # test transform
- transp_Xs = otda.transform(Xs=Xs)
- assert_equal(transp_Xs.shape, Xs.shape)
-
- transp_Xs_new = otda.transform(Xs_new)
-
- # check that the oos method is working
- assert_equal(transp_Xs_new.shape, Xs_new.shape)
-
- # check computation and dimensions if bias == True
- otda = ot.da.MappingTransport(kernel="gaussian", bias=True)
- otda.fit(Xs=Xs, Xt=Xt)
- assert_equal(otda.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
- assert_equal(otda.mapping_.shape, ((Xs.shape[0] + 1, Xt.shape[1])))
-
- # test margin constraints
- mu_s = unif(ns)
- mu_t = unif(nt)
- assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
- assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
@@ -561,29 +536,39 @@ def test_mapping_transport_class():
assert_equal(transp_Xs_new.shape, Xs_new.shape)
# check everything runs well with log=True
- otda = ot.da.MappingTransport(kernel="gaussian", log=True)
+ otda = ot.da.MappingTransport(kernel=kernel, bias=bias, log=True)
otda.fit(Xs=Xs, Xt=Xt)
assert len(otda.log_.keys()) != 0
+
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_mapping_transport_class_specific_seed(nx):
# check that it does not crash when derphi is very close to 0
+ ns = 20
+ nt = 30
np.random.seed(39)
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
- otda.fit(Xs=Xs, Xt=Xt)
+ otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt))
np.random.seed(None)
-def test_linear_mapping():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_linear_mapping(nx):
ns = 150
nt = 200
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
- A, b = ot.da.OT_mapping_linear(Xs, Xt)
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
- Xst = Xs.dot(A) + b
+ A, b = ot.da.OT_mapping_linear(Xsb, Xtb)
+
+ Xst = nx.to_numpy(nx.dot(Xsb, A) + b)
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
@@ -591,22 +576,26 @@ def test_linear_mapping():
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_linear_mapping_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_linear_mapping_class(nx):
ns = 150
nt = 200
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xsb, Xtb = nx.from_numpy(Xs, Xt)
+
otmap = ot.da.LinearTransport()
- otmap.fit(Xs=Xs, Xt=Xt)
+ otmap.fit(Xs=Xsb, Xt=Xtb)
assert hasattr(otmap, "A_")
assert hasattr(otmap, "B_")
assert hasattr(otmap, "A1_")
assert hasattr(otmap, "B1_")
- Xst = otmap.transform(Xs=Xs)
+ Xst = nx.to_numpy(otmap.transform(Xs=Xsb))
Ct = np.cov(Xt.T)
Cst = np.cov(Xst.T)
@@ -614,7 +603,9 @@ def test_linear_mapping_class():
np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
-def test_jcpot_transport_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""
@@ -627,6 +618,8 @@ def test_jcpot_transport_class():
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs1, ys1, Xs2, ys2, Xt, yt = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt, yt)
+
Xs = [Xs1, Xs2]
ys = [ys1, ys2]
@@ -649,19 +642,24 @@ def test_jcpot_transport_class():
for i in range(len(Xs)):
# test margin constraints w.r.t. uniform target weights for each coupling matrix
assert_allclose(
- np.sum(otda.coupling_[i], axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_[i], axis=0)), mu_t, rtol=1e-3, atol=1e-3)
# test margin constraints w.r.t. modified source weights for each source domain
assert_allclose(
- np.dot(otda.log_['D1'][i], np.sum(otda.coupling_[i], axis=1)), otda.proportions_, rtol=1e-3,
- atol=1e-3)
+ nx.to_numpy(
+ nx.dot(otda.log_['D1'][i], nx.sum(otda.coupling_[i], axis=1))
+ ),
+ nx.to_numpy(otda.proportions_),
+ rtol=1e-3,
+ atol=1e-3
+ )
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- Xs_new, _ = make_data_classif('3gauss', ns1 + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns1 + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -670,15 +668,16 @@ def test_jcpot_transport_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
- assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+ assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(*ys))))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
- [assert_equal(x.shape[0], y.shape[0]) for x, y in zip(transp_ys, ys)]
- [assert_equal(x.shape[1], len(np.unique(y))) for x, y in zip(transp_ys, ys)]
+ for x, y in zip(transp_ys, ys):
+ assert_equal(x.shape[0], y.shape[0])
+ assert_equal(x.shape[1], len(np.unique(nx.to_numpy(y))))
-def test_jcpot_barycenter():
+def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""
@@ -695,19 +694,23 @@ def test_jcpot_barycenter():
Xs1, ys1 = make_data_classif('2gauss_prop', ns1, nz=sigma, p=ps1)
Xs2, ys2 = make_data_classif('2gauss_prop', ns2, nz=sigma, p=ps2)
- Xt, yt = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
+ Xt, _ = make_data_classif('2gauss_prop', nt, nz=sigma, p=pt)
- Xs = [Xs1, Xs2]
- ys = [ys1, ys2]
+ Xs1b, ys1b, Xs2b, ys2b, Xtb = nx.from_numpy(Xs1, ys1, Xs2, ys2, Xt)
- prop = ot.bregman.jcpot_barycenter(Xs, ys, Xt, reg=.5, metric='sqeuclidean',
+ Xsb = [Xs1b, Xs2b]
+ ysb = [ys1b, ys2b]
+
+ prop = ot.bregman.jcpot_barycenter(Xsb, ysb, Xtb, reg=.5, metric='sqeuclidean',
numItermax=10000, stopThr=1e-9, verbose=False, log=False)
- np.testing.assert_allclose(prop, [1 - pt, pt], rtol=1e-3, atol=1e-3)
+ np.testing.assert_allclose(nx.to_numpy(prop), [1 - pt, pt], rtol=1e-3, atol=1e-3)
@pytest.mark.skipif(nosklearn, reason="No sklearn available")
-def test_emd_laplace_class():
+@pytest.skip_backend("jax")
+@pytest.skip_backend("tf")
+def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
ns = 150
@@ -716,6 +719,8 @@ def test_emd_laplace_class():
Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
+ Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)
+
otda = ot.da.EMDLaplaceTransport(reg_lap=0.01, max_iter=1000, tol=1e-9, verbose=False, log=True)
# test its computed
@@ -732,15 +737,15 @@ def test_emd_laplace_class():
mu_t = unif(nt)
assert_allclose(
- np.sum(otda.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=0)), mu_t, rtol=1e-3, atol=1e-3)
assert_allclose(
- np.sum(otda.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+ nx.to_numpy(nx.sum(otda.coupling_, axis=1)), mu_s, rtol=1e-3, atol=1e-3)
# test transform
transp_Xs = otda.transform(Xs=Xs)
[assert_equal(x.shape, y.shape) for x, y in zip(transp_Xs, Xs)]
- Xs_new, _ = make_data_classif('3gauss', ns + 1)
+ Xs_new = nx.from_numpy(make_data_classif('3gauss', ns + 1)[0])
transp_Xs_new = otda.transform(Xs_new)
# check that the oos method is working
@@ -750,7 +755,7 @@ def test_emd_laplace_class():
transp_Xt = otda.inverse_transform(Xt=Xt)
assert_equal(transp_Xt.shape, Xt.shape)
- Xt_new, _ = make_data_classif('3gauss2', nt + 1)
+ Xt_new = nx.from_numpy(make_data_classif('3gauss2', nt + 1)[0])
transp_Xt_new = otda.inverse_transform(Xt=Xt_new)
# check that the oos method is working
@@ -763,9 +768,9 @@ def test_emd_laplace_class():
# check label propagation
transp_yt = otda.transform_labels(ys)
assert_equal(transp_yt.shape[0], yt.shape[0])
- assert_equal(transp_yt.shape[1], len(np.unique(ys)))
+ assert_equal(transp_yt.shape[1], len(np.unique(nx.to_numpy(ys))))
# check inverse label propagation
transp_ys = otda.inverse_transform_labels(yt)
assert_equal(transp_ys.shape[0], ys.shape[0])
- assert_equal(transp_ys.shape[1], len(np.unique(yt)))
+ assert_equal(transp_ys.shape[1], len(np.unique(nx.to_numpy(yt))))
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 0dcf2da..12fd2b9 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -35,11 +35,7 @@ def test_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
- G0b = nx.from_numpy(G0)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', G0=G0, verbose=True)
Gb = nx.to_numpy(ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True))
@@ -105,11 +101,7 @@ def test_gromov_dtype_device(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- C1b = nx.from_numpy(C1, type_as=tp)
- C2b = nx.from_numpy(C2, type_as=tp)
- pb = nx.from_numpy(p, type_as=tp)
- qb = nx.from_numpy(q, type_as=tp)
- G0b = nx.from_numpy(G0, type_as=tp)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0, type_as=tp)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
@@ -136,11 +128,7 @@ def test_gromov_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
- G0b = nx.from_numpy(G0)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', G0=G0b, verbose=True)
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', G0=G0b, log=False)
nx.assert_same_dtype_device(C1b, Gb)
@@ -148,11 +136,7 @@ def test_gromov_device_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
- G0b = nx.from_numpy(G0b)
+ C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
nx.assert_same_dtype_device(C1b, Gb)
@@ -222,10 +206,7 @@ def test_entropic_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
G = ot.gromov.entropic_gromov_wasserstein(
C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
@@ -285,10 +266,7 @@ def test_entropic_gromov_dtype_device(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- C1b = nx.from_numpy(C1, type_as=tp)
- C2b = nx.from_numpy(C2, type_as=tp)
- pb = nx.from_numpy(p, type_as=tp)
- qb = nx.from_numpy(q, type_as=tp)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp)
Gb = ot.gromov.entropic_gromov_wasserstein(
C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
@@ -320,10 +298,7 @@ def test_pointwise_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
def loss(x, y):
return np.abs(x - y)
@@ -381,10 +356,7 @@ def test_sampled_gromov(nx):
C1 /= C1.max()
C2 /= C2.max()
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
+ C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q)
def loss(x, y):
return np.abs(x - y)
@@ -423,19 +395,15 @@ def test_gromov_barycenter(nx):
n_samples = 3
p = ot.unif(n_samples)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
)
Cbb = nx.to_numpy(ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42
))
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))
@@ -443,15 +411,15 @@ def test_gromov_barycenter(nx):
# test of gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'square_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
- np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err']))
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
Cb2 = ot.gromov.gromov_barycenters(
@@ -468,15 +436,15 @@ def test_gromov_barycenter(nx):
# test of gromov_barycenters with `log` on
Cb2_, err2_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
- 'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
+ 'kl_loss', max_iter=100, tol=1e-3, verbose=False, random_state=42, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
- np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err']))
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
@@ -495,11 +463,7 @@ def test_gromov_entropic_barycenter(nx):
n_samples = 2
p = ot.unif(n_samples)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)
Cb = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
@@ -523,7 +487,7 @@ def test_gromov_entropic_barycenter(nx):
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
- np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
+ np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err']))
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))
Cb2 = ot.gromov.entropic_gromov_barycenters(
@@ -548,7 +512,7 @@ def test_gromov_entropic_barycenter(nx):
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
- np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
+ np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err']))
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))
@@ -578,12 +542,7 @@ def test_fgw(nx):
M = ot.dist(ys, yt)
M /= M.max()
- Mb = nx.from_numpy(M)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- pb = nx.from_numpy(p)
- qb = nx.from_numpy(q)
- G0b = nx.from_numpy(G0)
+ Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True)
Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, G0=G0b, log=True)
@@ -681,13 +640,7 @@ def test_fgw_barycenter(nx):
n_samples = 3
p = ot.unif(n_samples)
- ysb = nx.from_numpy(ys)
- ytb = nx.from_numpy(yt)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- p1b = nx.from_numpy(p1)
- p2b = nx.from_numpy(p2)
- pb = nx.from_numpy(p)
+ ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)
Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False,
@@ -731,10 +684,8 @@ def test_gromov_wasserstein_linear_unmixing(nx):
Cdict = np.stack([C1, C2])
p = ot.unif(n)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- Cdictb = nx.from_numpy(Cdict)
- pb = nx.from_numpy(p)
+ C1b, C2b, Cdictb, pb = nx.from_numpy(C1, C2, Cdict, p)
+
tol = 10**(-5)
# Tests without regularization
reg = 0.
@@ -764,8 +715,8 @@ def test_gromov_wasserstein_linear_unmixing(nx):
np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
- np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
- np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
np.testing.assert_allclose(C1b_emb.shape, (n, n))
np.testing.assert_allclose(C2b_emb.shape, (n, n))
@@ -798,8 +749,8 @@ def test_gromov_wasserstein_linear_unmixing(nx):
np.testing.assert_allclose(unmixing2, [0., 1.], atol=1e-01)
np.testing.assert_allclose(C1_emb, nx.to_numpy(C1b_emb), atol=1e-06)
np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-06)
- np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
- np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
np.testing.assert_allclose(C1b_emb.shape, (n, n))
np.testing.assert_allclose(C2b_emb.shape, (n, n))
@@ -824,13 +775,14 @@ def test_gromov_wasserstein_dictionary_learning(nx):
dataset_means = [C.mean() for C in Cs]
np.random.seed(0)
Cdict_init = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(n_atoms, shape, shape))
+
if projection == 'nonnegative_symmetric':
Cdict_init = 0.5 * (Cdict_init + Cdict_init.transpose((0, 2, 1)))
Cdict_init[Cdict_init < 0.] = 0.
- Csb = [nx.from_numpy(C) for C in Cs]
- psb = [nx.from_numpy(p) for p in ps]
- qb = nx.from_numpy(q)
- Cdict_initb = nx.from_numpy(Cdict_init)
+
+ Csb = nx.from_numpy(*Cs)
+ psb = nx.from_numpy(*ps)
+ qb, Cdict_initb = nx.from_numpy(q, Cdict_init)
# Test: compare reconstruction error using initial dictionary and dictionary learned using this initialization
# > Compute initial reconstruction of samples on this random dictionary without backend
@@ -882,6 +834,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
)
total_reconstruction_b += reconstruction
+ total_reconstruction_b = nx.to_numpy(total_reconstruction_b)
np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
@@ -924,6 +877,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
)
total_reconstruction_b_bis += reconstruction
+ total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis)
np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
np.testing.assert_allclose(Cdict_bis, nx.to_numpy(Cdictb_bis), atol=1e-03)
@@ -969,6 +923,7 @@ def test_gromov_wasserstein_dictionary_learning(nx):
)
total_reconstruction_b_bis2 += reconstruction
+ total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
np.testing.assert_allclose(total_reconstruction_b_bis2, total_reconstruction_bis2, atol=1e-05)
@@ -985,12 +940,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx):
Ydict = np.stack([F, F])
p = ot.unif(n)
- C1b = nx.from_numpy(C1)
- C2b = nx.from_numpy(C2)
- Fb = nx.from_numpy(F)
- Cdictb = nx.from_numpy(Cdict)
- Ydictb = nx.from_numpy(Ydict)
- pb = nx.from_numpy(p)
+ C1b, C2b, Fb, Cdictb, Ydictb, pb = nx.from_numpy(C1, C2, F, Cdict, Ydict, p)
+
# Tests without regularization
reg = 0.
@@ -1022,8 +973,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx):
np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
- np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
- np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
np.testing.assert_allclose(C1b_emb.shape, (n, n))
np.testing.assert_allclose(C2b_emb.shape, (n, n))
@@ -1058,8 +1009,8 @@ def test_fused_gromov_wasserstein_linear_unmixing(nx):
np.testing.assert_allclose(C2_emb, nx.to_numpy(C2b_emb), atol=1e-03)
np.testing.assert_allclose(Y1_emb, nx.to_numpy(Y1b_emb), atol=1e-03)
np.testing.assert_allclose(Y2_emb, nx.to_numpy(Y2b_emb), atol=1e-03)
- np.testing.assert_allclose(reconstruction1, reconstruction1b, atol=1e-06)
- np.testing.assert_allclose(reconstruction2, reconstruction2b, atol=1e-06)
+ np.testing.assert_allclose(reconstruction1, nx.to_numpy(reconstruction1b), atol=1e-06)
+ np.testing.assert_allclose(reconstruction2, nx.to_numpy(reconstruction2b), atol=1e-06)
np.testing.assert_allclose(C1b_emb.shape, (n, n))
np.testing.assert_allclose(C2b_emb.shape, (n, n))
@@ -1093,12 +1044,10 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
dataset_feature_means = np.stack([Y.mean(axis=0) for Y in Ys])
Ydict_init = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(n_atoms, shape, 2))
- Csb = [nx.from_numpy(C) for C in Cs]
- Ysb = [nx.from_numpy(Y) for Y in Ys]
- psb = [nx.from_numpy(p) for p in ps]
- qb = nx.from_numpy(q)
- Cdict_initb = nx.from_numpy(Cdict_init)
- Ydict_initb = nx.from_numpy(Ydict_init)
+ Csb = nx.from_numpy(*Cs)
+ Ysb = nx.from_numpy(*Ys)
+ psb = nx.from_numpy(*ps)
+ qb, Cdict_initb, Ydict_initb = nx.from_numpy(q, Cdict_init, Ydict_init)
# Test: Compute initial reconstruction of samples on this random dictionary
alpha = 0.5
@@ -1151,6 +1100,7 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
)
total_reconstruction_b += reconstruction
+ total_reconstruction_b = nx.to_numpy(total_reconstruction_b)
np.testing.assert_array_less(total_reconstruction_b, initial_total_reconstruction)
np.testing.assert_allclose(total_reconstruction_b, total_reconstruction, atol=1e-05)
np.testing.assert_allclose(Cdict, nx.to_numpy(Cdictb), atol=1e-03)
@@ -1192,6 +1142,8 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
tol_outer=tol, tol_inner=tol, max_iter_outer=20, max_iter_inner=200
)
total_reconstruction_b_bis += reconstruction
+
+ total_reconstruction_b_bis = nx.to_numpy(total_reconstruction_b_bis)
np.testing.assert_allclose(total_reconstruction_b_bis, total_reconstruction_b, atol=1e-05)
# Test: without using adam optimizer, with log and verbose set to True
@@ -1237,4 +1189,5 @@ def test_fused_gromov_wasserstein_dictionary_learning(nx):
total_reconstruction_b_bis2 += reconstruction
# > Compare results with/without backend
+ total_reconstruction_b_bis2 = nx.to_numpy(total_reconstruction_b_bis2)
np.testing.assert_allclose(total_reconstruction_bis2, total_reconstruction_b_bis2, atol=1e-05)
diff --git a/test/test_optim.py b/test/test_optim.py
index 41f9cbe..67e9d13 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -32,9 +32,7 @@ def test_conditional_gradient(nx):
def fb(G):
return 0.5 * nx.sum(G ** 2)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
reg = 1e-1
@@ -74,9 +72,7 @@ def test_conditional_gradient_itermax(nx):
def fb(G):
return 0.5 * nx.sum(G ** 2)
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
reg = 1e-1
@@ -118,9 +114,7 @@ def test_generalized_conditional_gradient(nx):
reg1 = 1e-3
reg2 = 1e-1
- ab = nx.from_numpy(a)
- bb = nx.from_numpy(b)
- Mb = nx.from_numpy(M, type_as=ab)
+ ab, bb, Mb = nx.from_numpy(a, b, M)
G, log = ot.optim.gcg(a, b, M, reg1, reg2, f, df, verbose=True, log=True)
Gb, log = ot.optim.gcg(ab, bb, Mb, reg1, reg2, fb, df, verbose=True, log=True)
@@ -142,9 +136,12 @@ def test_line_search_armijo(nx):
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
old_fval = -123
+
+ xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)
+
# Should not throw an exception and return 0. for alpha
alpha, a, b = ot.optim.line_search_armijo(
- lambda x: 1, nx.from_numpy(xk), nx.from_numpy(pk), nx.from_numpy(gfk), old_fval
+ lambda x: 1, xkb, pkb, gfkb, old_fval
)
alpha_np, anp, bnp = ot.optim.line_search_armijo(
lambda x: 1, xk, pk, gfk, old_fval
diff --git a/test/test_ot.py b/test/test_ot.py
index 3e2d845..bb258e2 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -47,8 +47,7 @@ def test_emd_backends(nx):
G = ot.emd(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
@@ -68,8 +67,7 @@ def test_emd2_backends(nx):
val = ot.emd2(a, a, M)
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
valb = ot.emd2(ab, ab, Mb)
@@ -90,8 +88,7 @@ def test_emd_emd2_types_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- ab = nx.from_numpy(a, type_as=tp)
- Mb = nx.from_numpy(M, type_as=tp)
+ ab, Mb = nx.from_numpy(a, M, type_as=tp)
Gb = ot.emd(ab, ab, Mb)
@@ -117,8 +114,7 @@ def test_emd_emd2_devices_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -126,8 +122,7 @@ def test_emd_emd2_devices_tf():
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- ab = nx.from_numpy(a)
- Mb = nx.from_numpy(M)
+ ab, Mb = nx.from_numpy(a, M)
Gb = ot.emd(ab, ab, Mb)
w = ot.emd2(ab, ab, Mb)
nx.assert_same_dtype_device(Mb, Gb)
@@ -310,8 +305,8 @@ def test_free_support_barycenter_backends(nx):
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init)
- measures_locations2 = [nx.from_numpy(x) for x in measures_locations]
- measures_weights2 = [nx.from_numpy(x) for x in measures_weights]
+ measures_locations2 = nx.from_numpy(*measures_locations)
+ measures_weights2 = nx.from_numpy(*measures_weights)
X_init2 = nx.from_numpy(X_init)
X2 = ot.lp.free_support_barycenter(measures_locations2, measures_weights2, X_init2)
diff --git a/test/test_sliced.py b/test/test_sliced.py
index 91e0961..08ab4fb 100644
--- a/test/test_sliced.py
+++ b/test/test_sliced.py
@@ -123,9 +123,7 @@ def test_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.sliced_wasserstein_distance(x, y, projections=P)
@@ -153,9 +151,7 @@ def test_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -174,17 +170,13 @@ def test_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
@@ -203,9 +195,7 @@ def test_max_sliced_backend(nx):
n_projections = 20
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
val0 = ot.max_sliced_wasserstein_distance(x, y, projections=P)
@@ -233,9 +223,7 @@ def test_max_sliced_backend_type_devices(nx):
for tp in nx.__type_list__:
print(nx.dtype_device(tp))
- xb = nx.from_numpy(x, type_as=tp)
- yb = nx.from_numpy(y, type_as=tp)
- Pb = nx.from_numpy(P, type_as=tp)
+ xb, yb, Pb = nx.from_numpy(x, y, P, type_as=tp)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
@@ -254,17 +242,13 @@ def test_max_sliced_backend_device_tf():
# Check that everything stays on the CPU
with tf.device("/CPU:0"):
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
if len(tf.config.list_physical_devices('GPU')) > 0:
# Check that everything happens on the GPU
- xb = nx.from_numpy(x)
- yb = nx.from_numpy(y)
- Pb = nx.from_numpy(P)
+ xb, yb, Pb = nx.from_numpy(x, y, P)
valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb)
nx.assert_same_dtype_device(xb, valb)
assert nx.dtype_device(valb)[1].startswith("GPU")
diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py
index e8349d1..db59504 100644
--- a/test/test_unbalanced.py
+++ b/test/test_unbalanced.py
@@ -9,11 +9,9 @@ import ot
import pytest
from ot.unbalanced import barycenter_unbalanced
-from scipy.special import logsumexp
-
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_convergence(method):
+def test_unbalanced_convergence(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -28,36 +26,51 @@ def test_unbalanced_convergence(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
log=True,
verbose=True)
- loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method=method,
- verbose=True)
+ loss = nx.to_numpy(ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method=method, verbose=True
+ ))
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)
- logKtu = logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
- logKv = logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)
+ logKtu = nx.logsumexp(log["logu"][None, :] - M.T / epsilon, axis=1)
+ logKv = nx.logsumexp(log["logv"][None, :] - M / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
# check if sinkhorn_unbalanced2 returns the correct loss
- np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
+ np.testing.assert_allclose(nx.to_numpy(nx.sum(G * M)), loss, atol=1e-5)
+
+ # check in case no histogram is provided
+ M_np = nx.to_numpy(M)
+ a_np, b_np = np.array([]), np.array([])
+ a, b = nx.from_numpy(a_np, b_np)
+
+ G = ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ G_np = ot.unbalanced.sinkhorn_unbalanced(
+ a_np, b_np, M_np, reg=epsilon, reg_m=reg_m, method=method, verbose=True
+ )
+ np.testing.assert_allclose(G_np, nx.to_numpy(G))
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_multiple_inputs(method):
+def test_unbalanced_multiple_inputs(nx, method):
# test generalized sinkhorn for unbalanced OT
n = 100
rng = np.random.RandomState(42)
@@ -72,6 +85,8 @@ def test_unbalanced_multiple_inputs(method):
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
reg_m=reg_m,
method=method,
@@ -80,23 +95,24 @@ def test_unbalanced_multiple_inputs(method):
# check fixed point equations
# in log-domain
fi = reg_m / (reg_m + epsilon)
- logb = np.log(b + 1e-16)
- loga = np.log(a + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logb = nx.log(b + 1e-16)
+ loga = nx.log(a + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logb - logKtu)
u_final = fi * (loga - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
assert len(loss) == b.shape[1]
-def test_stabilized_vs_sinkhorn():
+def test_stabilized_vs_sinkhorn(nx):
# test if stable version matches sinkhorn
n = 100
@@ -112,19 +128,27 @@ def test_stabilized_vs_sinkhorn():
M /= np.median(M)
epsilon = 0.1
reg_m = 1.
- G, log = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, reg=epsilon,
- method="sinkhorn_stabilized",
- reg_m=reg_m,
- log=True,
- verbose=True)
- G2, log2 = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method="sinkhorn", log=True)
+
+ ab, bb, Mb = nx.from_numpy(a, b, M)
+
+ G, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn_stabilized", log=True
+ )
+ G2, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ ab, bb, Mb, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G2_np, _ = ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method="sinkhorn", log=True
+ )
+ G = nx.to_numpy(G)
+ G2 = nx.to_numpy(G2)
np.testing.assert_allclose(G, G2, atol=1e-5)
+ np.testing.assert_allclose(G2, G2_np, atol=1e-5)
@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized"])
-def test_unbalanced_barycenter(method):
+def test_unbalanced_barycenter(nx, method):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -138,25 +162,29 @@ def test_unbalanced_barycenter(method):
epsilon = 1.
reg_m = 1.
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method=method, log=True, verbose=True)
+ A, M = nx.from_numpy(A, M)
+
+ q, log = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method=method, log=True, verbose=True
+ )
# check fixed point equations
fi = reg_m / (reg_m + epsilon)
- logA = np.log(A + 1e-16)
- logq = np.log(q + 1e-16)[:, None]
- logKtu = logsumexp(log["logu"][:, None, :] - M[:, :, None] / epsilon,
- axis=0)
- logKv = logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
+ logA = nx.log(A + 1e-16)
+ logq = nx.log(q + 1e-16)[:, None]
+ logKtu = nx.logsumexp(
+ log["logu"][:, None, :] - M[:, :, None] / epsilon, axis=0
+ )
+ logKv = nx.logsumexp(log["logv"][None, :] - M[:, :, None] / epsilon, axis=1)
v_final = fi * (logq - logKtu)
u_final = fi * (logA - logKv)
np.testing.assert_allclose(
- u_final, log["logu"], atol=1e-05)
+ nx.to_numpy(u_final), nx.to_numpy(log["logu"]), atol=1e-05)
np.testing.assert_allclose(
- v_final, log["logv"], atol=1e-05)
+ nx.to_numpy(v_final), nx.to_numpy(log["logv"]), atol=1e-05)
-def test_barycenter_stabilized_vs_sinkhorn():
+def test_barycenter_stabilized_vs_sinkhorn(nx):
# test generalized sinkhorn for unbalanced OT barycenter
n = 100
rng = np.random.RandomState(42)
@@ -170,21 +198,24 @@ def test_barycenter_stabilized_vs_sinkhorn():
epsilon = 0.5
reg_m = 10
- qstable, log = barycenter_unbalanced(A, M, reg=epsilon,
- reg_m=reg_m, log=True,
- tau=100,
- method="sinkhorn_stabilized",
- verbose=True
- )
- q, log = barycenter_unbalanced(A, M, reg=epsilon, reg_m=reg_m,
- method="sinkhorn",
- log=True)
+ Ab, Mb = nx.from_numpy(A, M)
- np.testing.assert_allclose(
- q, qstable, atol=1e-05)
+ qstable, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, log=True, tau=100,
+ method="sinkhorn_stabilized", verbose=True
+ )
+ q, _ = barycenter_unbalanced(
+ Ab, Mb, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q_np, _ = barycenter_unbalanced(
+ A, M, reg=epsilon, reg_m=reg_m, method="sinkhorn", log=True
+ )
+ q, qstable = nx.to_numpy(q, qstable)
+ np.testing.assert_allclose(q, qstable, atol=1e-05)
+ np.testing.assert_allclose(q, q_np, atol=1e-05)
-def test_wrong_method():
+def test_wrong_method(nx):
n = 10
rng = np.random.RandomState(42)
@@ -199,19 +230,20 @@ def test_wrong_method():
epsilon = 1.
reg_m = 1.
+ a, b, M = nx.from_numpy(a, b, M)
+
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
- reg_m=reg_m,
- method='badmethod',
- log=True,
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced(
+ a, b, M, reg=epsilon, reg_m=reg_m, method='badmethod',
+ log=True, verbose=True
+ )
with pytest.raises(ValueError):
- ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, reg_m,
- method='badmethod',
- verbose=True)
+ ot.unbalanced.sinkhorn_unbalanced2(
+ a, b, M, epsilon, reg_m, method='badmethod', verbose=True
+ )
-def test_implemented_methods():
+def test_implemented_methods(nx):
IMPLEMENTED_METHODS = ['sinkhorn', 'sinkhorn_stabilized']
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling']
NOT_VALID_TOKENS = ['foo']
@@ -228,6 +260,9 @@ def test_implemented_methods():
M = ot.dist(x, x)
epsilon = 1.
reg_m = 1.
+
+ a, b, M, A = nx.from_numpy(a, b, M, A)
+
for method in IMPLEMENTED_METHODS:
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, reg_m,
method=method)
diff --git a/test/test_weak.py b/test/test_weak.py
index c4c3278..945efb1 100644
--- a/test/test_weak.py
+++ b/test/test_weak.py
@@ -45,9 +45,7 @@ def test_weak_ot_bakends(nx):
G = ot.weak_optimal_transport(xs, xt, u, u)
- xs2 = nx.from_numpy(xs)
- xt2 = nx.from_numpy(xt)
- u2 = nx.from_numpy(u)
+ xs2, xt2, u2 = nx.from_numpy(xs, xt, u)
G2 = ot.weak_optimal_transport(xs2, xt2, u2, u2)
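
Note (not part of the patch): the hunks above repeatedly apply the same refactor, collapsing several single-array `nx.from_numpy` calls into one variadic call and wrapping backend outputs with `nx.to_numpy` before NumPy assertions. A minimal standalone sketch of that pattern is below; it assumes the `NumpyBackend` class exported by `ot.backend` purely for illustration, since any backend instance exposes the same `from_numpy`/`to_numpy` interface once this PR is merged.

```python
# Sketch of the conversion pattern used throughout the updated tests.
# Assumes a POT version that includes this PR's variadic from_numpy/to_numpy.
import numpy as np
import ot
from ot.backend import NumpyBackend  # illustrative choice; any backend works

nx = NumpyBackend()

rng = np.random.RandomState(0)
xs, xt = rng.randn(5, 2), rng.randn(5, 2)
a, b = ot.unif(5), ot.unif(5)
M = ot.dist(xs, xt)

# One call converts several NumPy arrays to the backend, in order.
ab, bb, Mb = nx.from_numpy(a, b, M)

# An optional type_as argument casts everything to a reference dtype/device,
# as done in the *_dtype_device tests: nx.from_numpy(a, b, M, type_as=tp)

Gb = ot.emd(ab, bb, Mb)

# to_numpy also accepts several arguments and returns them converted,
# which is what the assertions in the tests rely on.
G_np, M_np = nx.to_numpy(Gb, Mb)
np.testing.assert_allclose(G_np.sum(), 1.0, atol=1e-12)
```

The same idea drives the unbalanced-test changes: SciPy's `logsumexp` is replaced by the backend's `nx.log`/`nx.logsumexp`, and results are passed through `nx.to_numpy` before being compared with `np.testing` helpers.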