#![allow(dead_code)]

/// Extent of each of the `D` dimensions of a tensor.
pub type Shape<const D: usize> = [usize; D];
/// A multi-index addressing a single element of a `D`-dimensional tensor.
pub type Index<const D: usize> = Shape<D>;

/// A dense `D`-dimensional tensor stored contiguously in row-major order
/// (the last index varies fastest).
///
/// Invariant relied upon by the unchecked accessors (and therefore by the
/// safe `Index`/`IndexMut` impls): `data.len()` equals the product of
/// `shape`, that product fits in a `usize`, and `strides` is the row-major
/// stride vector of `shape`. The constructors establish this invariant.
pub struct Tensor<T, const D: usize> {
    data: Vec<T>,
    shape: Shape<D>,
    strides: Shape<D>,
}

pub type Tensor1<T> = Tensor<T, 1>;
pub type Tensor2<T> = Tensor<T, 2>;
pub type Tensor3<T> = Tensor<T, 3>;
pub type Tensor4<T> = Tensor<T, 4>;

pub type Index1 = Index<1>;
pub type Index2 = Index<2>;
pub type Index3 = Index<3>;
pub type Index4 = Index<4>;

impl<T, const D: usize> Tensor<T, D> {
    /// Computes the row-major strides and the total element count for `shape`.
    ///
    /// # Panics
    /// Panics if `D == 0`, if any dimension is zero, or if the element count
    /// does not fit in a `usize`.
    fn row_major_layout(shape: &Shape<D>) -> (Shape<D>, usize) {
        if D == 0 {
            panic!("Empty shape not allowed.");
        }
        if shape.iter().any(|&n| n == 0) {
            panic!("Empty dimensions not allowed.");
        }
        let mut strides: Shape<D> = [0; D];
        let mut len: usize = 1;
        // Innermost dimension first: each stride is the product of all later extents.
        for d in (0..D).rev() {
            strides[d] = len;
            // Checked: a wrapped product would leave `data` shorter than `shape`
            // claims, letting in-bounds indices reach past the allocation.
            len = match len.checked_mul(shape[d]) {
                Some(n) => n,
                None => panic!(
                    "Tensor of shape {:?} has more entries than fit in a usize.",
                    shape
                ),
            };
        }
        (strides, len)
    }

    /// Creates a tensor of the given `shape` that takes ownership of `x` as its
    /// row-major element storage.
    ///
    /// # Panics
    /// Panics if `D == 0`, if any dimension is zero, if the element count
    /// overflows `usize`, or if `x.len()` differs from the element count.
    pub fn new_from(shape: Shape<D>, x: Vec<T>) -> Self {
        let (strides, len) = Self::row_major_layout(&shape);
        if len != x.len() {
            panic!("Vector of length {} cannot fill tensor with {} entries.", x.len(), len);
        }
        Self { data: x, shape, strides }
    }

    /// Maps a multi-index to its offset in the flat storage (no bounds check).
    #[inline(always)]
    fn flatten_idx(&self, idx: &Index<D>) -> usize {
        // NOTE: This is a very hot code path. Should benchmark versus explicit loop.
        idx.iter().zip(self.strides.iter()).fold(0, |sum, (i, s)| sum + i * s)
    }

    /// Panics with a descriptive message if any component of `idx` is out of bounds.
    fn bound_check_panic(&self, idx: &Index<D>) {
        for (d, (&i, &n)) in idx.iter().zip(self.shape.iter()).enumerate() {
            if i >= n {
                panic!(
                    "{}-dimensional tensor index is out of bounds in dimension {} ({} >= {}).",
                    D, d, i, n
                )
            }
        }
    }

    /// Returns `true` if every component of `idx` is below the corresponding extent.
    pub fn in_bounds(&self, idx: &Index<D>) -> bool {
        idx.iter().zip(self.shape.iter()).all(|(i, n)| i < n)
    }

    /// The extent of each dimension.
    pub fn shape(&self) -> &Shape<D> {
        &self.shape
    }

    /// Returns the element at `idx`, or `None` if `idx` is out of bounds.
    pub fn el(&self, idx: &Index<D>) -> Option<&T> {
        if self.in_bounds(idx) {
            // SAFETY: `in_bounds` just confirmed every component is in range.
            Some(unsafe { self.el_unchecked(idx) })
        } else {
            None
        }
    }

    /// Returns a reference to the element at `idx` without bounds checking.
    ///
    /// # Safety
    /// `idx` must be in bounds (`self.in_bounds(idx)` must be `true`);
    /// otherwise the access is out of bounds and the behavior is undefined.
    pub unsafe fn el_unchecked(&self, idx: &Index<D>) -> &T {
        // By the struct invariant, an in-bounds `idx` flattens to `< data.len()`.
        self.data.get_unchecked(self.flatten_idx(idx))
    }

    /// Returns a mutable reference to the element at `idx` without bounds checking.
    ///
    /// # Safety
    /// `idx` must be in bounds (`self.in_bounds(idx)` must be `true`);
    /// otherwise the access is out of bounds and the behavior is undefined.
    pub unsafe fn el_unchecked_mut(&mut self, idx: &Index<D>) -> &mut T {
        let flat_idx = self.flatten_idx(idx);
        self.data.get_unchecked_mut(flat_idx)
    }

    /// Total number of elements (the product of the shape).
    pub fn flat_len(&self) -> usize {
        self.data.len()
    }

    /// Number of bytes occupied by the elements (`flat_len() * size_of::<T>()`).
    pub fn size(&self) -> usize {
        self.flat_len() * std::mem::size_of::<T>()
    }

    /// The elements as a flat, row-major slice.
    pub fn data(&self) -> &[T] {
        &self.data
    }
}

impl<T: Clone, const D: usize> Tensor<T, D> {
    /// Creates a tensor of the given `shape` with every element set to `x`.
    ///
    /// # Panics
    /// Panics if `D == 0`, if any dimension is zero, or if the element count
    /// overflows `usize`.
    pub fn new(shape: Shape<D>, x: T) -> Self {
        let (strides, len) = Self::row_major_layout(&shape);
        Self { data: vec![x; len], shape, strides }
    }

    /// Sets every element to `x`.
    pub fn fill(&mut self, x: T) {
        self.data.fill(x);
    }
}

impl<T: Copy, const D: usize> Tensor<T, D> {
    /// Overwrites all elements from a flat, row-major slice.
    ///
    /// # Panics
    /// Panics if `x.len() != self.flat_len()`.
    pub fn fill_with(&mut self, x: &[T]) {
        // `copy_from_slice` already panics on a length mismatch.
        self.data.copy_from_slice(x)
    }
}

impl<T: std::fmt::Display> Tensor<T, 2> {
    /// Prints the matrix to stdout, one row per line, entries followed by a space.
    pub fn dirty_print(&self) {
        for i in 0..self.shape[0] {
            for j in 0..self.shape[1] {
                print!("{} ", self[[i, j]]);
            }
            println!();
        }
    }
}

impl<T, const D: usize> std::ops::Index<Index<D>> for Tensor<T, D> {
    type Output = T;

    /// Bounds-checked element access by value index; panics if out of bounds.
    fn index(&self, idx: Index<D>) -> &Self::Output {
        &self[&idx]
    }
}

impl<T, const D: usize> std::ops::Index<&Index<D>> for Tensor<T, D> {
    type Output = T;

    /// Bounds-checked element access by reference index; panics if out of bounds.
    fn index(&self, idx: &Index<D>) -> &Self::Output {
        self.bound_check_panic(idx);
        // SAFETY: `bound_check_panic` returned, so `idx` is in bounds.
        unsafe { self.el_unchecked(idx) }
    }
}

impl<T, const D: usize> std::ops::IndexMut<Index<D>> for Tensor<T, D> {
    /// Bounds-checked mutable element access by value index; panics if out of bounds.
    fn index_mut(&mut self, idx: Index<D>) -> &mut Self::Output {
        &mut self[&idx]
    }
}

impl<T, const D: usize> std::ops::IndexMut<&Index<D>> for Tensor<T, D> {
    /// Bounds-checked mutable element access by reference index; panics if out of bounds.
    fn index_mut(&mut self, idx: &Index<D>) -> &mut Self::Output {
        self.bound_check_panic(idx);
        // SAFETY: `bound_check_panic` returned, so `idx` is in bounds.
        unsafe { self.el_unchecked_mut(idx) }
    }
}

/// Consumes the tensor, yielding its elements in row-major order.
impl<T, const D: usize> IntoIterator for Tensor<T, D> {
    type Item = T;
    type IntoIter = std::vec::IntoIter<T>;

    fn into_iter(self) -> Self::IntoIter {
        self.data.into_iter()
    }
}

/// Borrows the tensor, yielding references to its elements in row-major order.
impl<'a, T, const D: usize> IntoIterator for &'a Tensor<T, D> {
    type Item = &'a T;
    type IntoIter = std::slice::Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.data.iter()
    }
}

// Note: elements are addressed individually by full multi-index
// (`t[[i, j]]` or `t[&idx]`) through the Index/IndexMut impls above.