diff options
Diffstat (limited to '11/src/tensor.rs')
-rw-r--r-- | 11/src/tensor.rs | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/11/src/tensor.rs b/11/src/tensor.rs new file mode 100644 index 0000000..6898701 --- /dev/null +++ b/11/src/tensor.rs @@ -0,0 +1,175 @@ +#![allow(dead_code)] + +pub type Shape<const D: usize> = [usize; D]; +pub type Index<const D: usize> = Shape<D>; + +pub struct Tensor<T, const D: usize> { + data: Vec<T>, + shape: Shape<D>, + strides: Shape<D> +} + +pub type Tensor1<T> = Tensor<T, 1>; +pub type Tensor2<T> = Tensor<T, 2>; +pub type Tensor3<T> = Tensor<T, 3>; +pub type Tensor4<T> = Tensor<T, 4>; +pub type Index1 = Index<1>; +pub type Index2 = Index<2>; +pub type Index3 = Index<3>; +pub type Index4 = Index<4>; + + +impl<T: Copy, const D: usize> Tensor<T, D> { + pub fn new(shape: Shape<D>, x: T) -> Self { + let dim = D; + if dim == 0 { panic!("Empty shape not allowed."); } + + let mut len = shape[dim-1]; + let mut strides: Shape<D> = [0; D]; + strides[dim-1] = 1; + for d in (1..dim).rev() { // d=dim-1, …, 1. + strides[d-1] = shape[d]*strides[d]; + len *= shape[d-1]; + } + if len == 0 { panic!("Empty dimensions not allowed."); } + + Self { + data: vec![x; len], + shape: shape, + strides: strides, + } + } + + pub fn new_from(shape: Shape<D>, x: Vec<T>) -> Self { + let dim = D; + if dim == 0 { panic!("Empty shape not allowed."); } + + let mut len = shape[dim-1]; + let mut strides: Shape<D> = [0; D]; + strides[dim-1] = 1; + for d in (1..dim).rev() { // d=dim-1, …, 1. + strides[d-1] = shape[d]*strides[d]; + len *= shape[d-1]; + } + if len == 0 { panic!("Empty dimensions not allowed."); } + + if len != x.len() { panic!("Vector of length {} cannot fill tensor with {} entries.", x.len(), len); } + + Self { + data: x, + shape: shape, + strides: strides, + } + } + + + #[inline(always)] + fn flatten_idx(self: & Self, idx: & Index<D>) -> usize { + // NOTE: This is a very hot code path. Should benchmark versus explicit loop. + idx.iter().zip(self.strides.iter()).fold(0, |sum, (i, s)| sum + i*s) + } + + fn bound_check_panic(self: & Self, idx: & Index<D>) -> () { + for d in 0..self.dim() { + let i = *(unsafe { idx.get_unchecked(d) }); + if i >= self.shape[d] { + panic!("{}-dimensional tensor index is out of bounds in dimension {} ({} >= {}).", self.dim(), d, i, self.shape[d]) + } + } + } + + pub fn in_bounds(self: & Self, idx: & Index<D>) -> bool { + for d in 0..self.dim() { + let i = *(unsafe { idx.get_unchecked(d) }); + if i >= self.shape[d] { + return false; + } + } + true + } + + pub fn dim(self: & Self) -> usize { D } + + pub fn shape(self: & Self) -> & Shape<D> { &self.shape } + + pub fn el(self: & Self, idx: & Index<D>) -> Option<& T> { + if self.in_bounds(idx) { Some(unsafe { self.el_unchecked(idx) }) } + else { None } + } + + pub unsafe fn el_unchecked(self: & Self, idx: & Index<D>) -> & T { self.data.get_unchecked(self.flatten_idx(idx)) } + + pub unsafe fn el_unchecked_mut(self: &mut Self, idx: & Index<D>) -> &mut T { + let flat_idx = self.flatten_idx(idx); + self.data.get_unchecked_mut(flat_idx) + } + + pub fn flat_len(self: & Self) -> usize { self.data.len() } + pub fn size(self: & Self) -> usize { self.flat_len()*std::mem::size_of::<T>() } + + pub fn fill_with(self: &mut Self, x: & [T]) -> () { + // Already panics on size mismatch. + self.data.copy_from_slice(x) + } + + pub fn fill(self: &mut Self, x: T) -> () { + self.data.fill(x); + } + + pub fn data(self: & Self) -> & [T] { &self.data } +} + +impl<T: Copy + std::fmt::Display> Tensor<T, 2> { + pub fn dirty_print(self: & Self) { + for i in 0..self.shape[0] { + for j in 0..self.shape[1] { + print!("{} ", self[[i, j]]); + } + println!(""); + } + } +} + +impl<T: Copy, const D: usize> std::ops::Index<Index<D>> for Tensor<T, D> { + type Output = T; + + fn index(self: & Self, idx: Index<D>) -> & Self::Output { + self.bound_check_panic(&idx); + unsafe { self.el_unchecked(&idx) } + } +} + +impl<T: Copy, const D: usize> std::ops::Index<& Index<D>> for Tensor<T, D> { + type Output = T; + + fn index(self: & Self, idx: & Index<D>) -> & Self::Output { + self.bound_check_panic(idx); + unsafe { self.el_unchecked(idx) } + } +} + +impl<T: Copy, const D: usize> std::ops::IndexMut<Index<D>> for Tensor<T, D> { + fn index_mut(self: &mut Self, idx: Index<D>) -> &mut Self::Output { + self.bound_check_panic(&idx); + unsafe { self.el_unchecked_mut(&idx) } + } +} + +impl<T: Copy, const D: usize> std::ops::IndexMut<& Index<D>> for Tensor<T, D> { + fn index_mut(self: &mut Self, idx: & Index<D>) -> &mut Self::Output { + self.bound_check_panic(idx); + unsafe { self.el_unchecked_mut(idx) } + } +} + +impl<T: Copy, const D: usize> IntoIterator for Tensor<T, D> { + type Item = T; + type IntoIter = std::vec::IntoIter<Self::Item>; + + fn into_iter(self) -> Self::IntoIter { self.data.into_iter() } +} + + +// FIXME: Should have a proper IntoIter implementing IntoIter for &'a Tensor<T, D>. + +// Note: Tensor is also sliceable (due to the Index implementations) |