use crate::index::BoundedSumIter;

/// Multi-dimensional vector indexed by a bounded-sum multi-index set.
///
/// Storage is split by first dimension for cache-friendly access
/// during the unidirectional matrix-vector products: `data[i]` contains
/// all entries whose multi-index starts with `i`, in lexicographic order.
///
/// Values are `i64` fixed-point throughout.
#[derive(Debug, Clone)]
pub struct MultiDimVec {
    /// Entries grouped by first index: row `i` holds every entry whose
    /// multi-index starts with `i`.
    pub data: Vec<Vec<i64>>,
    /// Index set describing the layout of `data`.
    iter: BoundedSumIter,
}

impl MultiDimVec {
    /// Creates a zero-filled vector whose layout matches `iter`.
    ///
    /// Row `i` receives as many entries as element `i` of
    /// `iter.num_values_per_first_index()`.
    pub fn new(iter: BoundedSumIter) -> Self {
        let sizes = iter.num_values_per_first_index();
        let data = sizes.iter().map(|&len| vec![0i64; len]).collect();
        Self { data, iter }
    }

    /// Returns the index set describing this vector's layout.
    pub fn iter(&self) -> &BoundedSumIter {
        &self.iter
    }

    /// Re-targets this vector at a new index set, reusing existing allocations.
    ///
    /// Entry values are *not* cleared: entries that survive the resize keep
    /// their previous contents, and only newly created entries are zero.
    ///
    /// If `iter` has the same dimension and first-index bound as the current
    /// index set, the storage is kept as-is and only the index set is replaced.
    pub fn reset_with_iter(&mut self, iter: BoundedSumIter) {
        if self.iter.dim() == iter.dim()
            && self.iter.first_index_bound() == iter.first_index_bound()
        {
            // NOTE(review): assumes (dim, first_index_bound) fully determine the
            // row sizes — confirm against BoundedSumIter.
            self.iter = iter;
            return;
        }
        let sizes = iter.num_values_per_first_index();
        // Grow or shrink the row count first so every row below has a size.
        self.data.resize(sizes.len(), Vec::new());
        for (row, &len) in self.data.iter_mut().zip(sizes.iter()) {
            row.resize(len, 0);
        }
        self.iter = iter;
    }

    /// Exchanges storage and index set with `other` without copying entries.
    pub fn swap(&mut self, other: &mut MultiDimVec) {
        std::mem::swap(self, other);
    }

    /// Element-wise `self += other`.
    ///
    /// Both vectors must share the same layout. This is checked in debug
    /// builds; in release builds a mismatch would silently stop at the
    /// shorter row or row count.
    pub fn add_assign(&mut self, other: &MultiDimVec) {
        self.zip_apply(other, |a, b| *a += b);
    }

    /// Element-wise `self -= other`.
    ///
    /// Both vectors must share the same layout. This is checked in debug
    /// builds; in release builds a mismatch would silently stop at the
    /// shorter row or row count.
    pub fn sub_assign(&mut self, other: &MultiDimVec) {
        self.zip_apply(other, |a, b| *a -= b);
    }

    /// Applies `op` to each pair of corresponding entries of `self` and `other`.
    fn zip_apply(&mut self, other: &MultiDimVec, mut op: impl FnMut(&mut i64, i64)) {
        debug_assert!(self.same_shape(other), "MultiDimVec layout mismatch");
        for (row, other_row) in self.data.iter_mut().zip(&other.data) {
            for (a, &b) in row.iter_mut().zip(other_row) {
                op(a, b);
            }
        }
    }

    /// Returns `true` when both vectors have the same row count and row lengths.
    fn same_shape(&self, other: &MultiDimVec) -> bool {
        self.data.len() == other.data.len()
            && self
                .data
                .iter()
                .zip(&other.data)
                .all(|(a, b)| a.len() == b.len())
    }

    /// Sum of squares of all entries, accumulated in `i128`.
    ///
    /// Each individual square fits in `i128`. The running sum can still
    /// overflow, but only when entries have magnitudes close to 2^63.
    pub fn squared_l2_norm(&self) -> i128 {
        self.data
            .iter()
            .flatten()
            .map(|&v| {
                let v = i128::from(v);
                v * v
            })
            .sum()
    }
}

/// Iterator over all stored entries of a [`MultiDimVec`], yielding each value
/// as an `i64` (by copy); the multi-indices themselves are not yielded.
/// Values are read row by row from `data`, following the order in which the
/// vector's `BoundedSumIter` visits multi-indices.
pub struct MultiDimVecIter<'a> { vec: &'a MultiDimVec, jump: BoundedSumIter, last_dim_idx: usize, last_dim_count: usize, first_idx: usize, tail_counter: usize, } impl<'a> MultiDimVecIter<'a> { pub fn new(vec: &'a MultiDimVec) -> Self { let mut jump = vec.iter.clone(); jump.reset(); let count = if jump.valid() { jump.last_dim_count() } else { 0 }; Self { vec, jump, last_dim_idx: 0, last_dim_count: count, first_idx: 0, tail_counter: 0, } } } impl<'a> Iterator for MultiDimVecIter<'a> { type Item = i64; fn next(&mut self) -> Option { if !self.jump.valid() { return None; } let val = self.vec.data[self.first_idx][self.tail_counter]; self.last_dim_idx += 1; self.tail_counter += 1; if self.last_dim_idx >= self.last_dim_count { self.last_dim_idx = 0; self.jump.next(); if self.jump.valid() { let new_first = self.jump.first_index(); if new_first != self.first_idx { self.first_idx = new_first; self.tail_counter = 0; } self.last_dim_count = self.jump.last_dim_count(); } } Some(val) } } #[cfg(test)] mod tests { use super::*; #[test] fn storage_size() { let it = BoundedSumIter::new(3, 2); let v = MultiDimVec::new(it.clone()); let total: usize = v.data.iter().map(|r| r.len()).sum(); assert_eq!(total, it.num_values()); } #[test] fn iter_all_values() { let it = BoundedSumIter::new(3, 2); let mut v = MultiDimVec::new(it.clone()); // Fill with sequential values let mut val = 1i64; for row in &mut v.data { for cell in row.iter_mut() { *cell = val; val += 1; } } let collected: Vec = MultiDimVecIter::new(&v).collect(); assert_eq!(collected.len(), it.num_values()); assert_eq!(collected[0], 1); } }