// Cord/crates/cord-sparse/src/index.rs
//
// 176 lines
// 4.2 KiB
// Rust

/// Iterator over the downward-closed multi-index set
/// `{(i_1, ..., i_d) | i_k >= 0, i_1 + ... + i_d <= bound}`.
///
/// Does not iterate over the last dimension explicitly; instead
/// reports how many values the last dimension can take at each
/// position. This is the key to the unidirectional principle:
/// multiply along one dimension at a time, cycling indices.
///
/// Only the first `d - 1` index components (the "head") are stored;
/// the iterator steps through every head whose component sum does not
/// exceed `bound`.
#[derive(Debug, Clone)]
pub struct BoundedSumIter {
/// Number of dimensions `d` of the multi-index (at least 1).
d: usize,
/// Inclusive upper limit on the sum of all index components.
bound: usize,
/// Current values of the first `d - 1` index components; empty when `d == 1`.
head: Vec<usize>,
/// Running sum of the entries of `head`, updated alongside it when stepping.
head_sum: usize,
/// `true` while the iterator points at a head; cleared once it is exhausted.
valid: bool,
}
impl BoundedSumIter {
    /// Builds an iterator over `d` dimensions whose index sum is limited to
    /// `bound`, positioned at the all-zero head.
    ///
    /// # Panics
    ///
    /// Panics if `d` is zero.
    pub fn new(d: usize, bound: usize) -> Self {
        assert!(d >= 1);
        // The last dimension is never stored, hence one slot fewer than `d`.
        let head_len = d.saturating_sub(1);
        Self {
            head: vec![0; head_len],
            head_sum: 0,
            valid: true,
            d,
            bound,
        }
    }

    /// Number of values the implicit last index can take at the current head,
    /// i.e. `bound - head_sum + 1`.
    pub fn last_dim_count(&self) -> usize {
        let slack = self.bound - self.head_sum;
        slack + 1
    }

    /// Advances to the next head, or marks the iterator invalid when no head
    /// is left.
    ///
    /// The rightmost head component is incremented while the sum is below the
    /// bound; otherwise the rightmost non-zero component is cleared and the
    /// component to its left is incremented (a carry).
    pub fn next(&mut self) {
        // A single dimension has no head to advance: one position only.
        if self.d == 1 {
            self.valid = false;
            return;
        }

        let last = self.d - 2;
        if self.head_sum < self.bound {
            // Still room under the bound: bump the rightmost head component.
            self.head_sum += 1;
            self.head[last] += 1;
            return;
        }

        // The sum has reached the bound: carry out of the rightmost non-zero slot.
        match self.head.iter().rposition(|&value| value != 0) {
            // Only the first component is non-zero, so nothing is left to carry into.
            Some(0) => {
                self.head[0] = 0;
                self.head_sum = 0;
                self.valid = false;
            }
            Some(pos) => {
                // Clearing `head[pos]` removes its value from the sum and the
                // carry adds one back, for a net change of `head[pos] - 1`.
                self.head_sum -= self.head[pos] - 1;
                self.head[pos] = 0;
                self.head[pos - 1] += 1;
            }
            // Every component is zero (this happens when `bound` is 0).
            None => self.valid = false,
        }
    }

    /// Whether the iterator currently points at a head.
    pub fn valid(&self) -> bool {
        self.valid
    }

    /// Rewinds to the all-zero head and marks the iterator valid again.
    pub fn reset(&mut self) {
        self.head.fill(0);
        self.head_sum = 0;
        self.valid = true;
    }

    /// Value of the first stored index component, or 0 when `d == 1` and
    /// nothing is stored.
    pub fn first_index(&self) -> usize {
        self.head.first().copied().unwrap_or(0)
    }

    /// Value of the stored index component at position `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim` is not below `d - 1`; the last dimension is not stored.
    pub fn index_at(&self, dim: usize) -> usize {
        self.head[dim]
    }

    /// Number of dimensions `d`.
    pub fn dim(&self) -> usize {
        self.d
    }

    /// Exclusive upper bound of the first index component (`bound + 1`).
    pub fn first_index_bound(&self) -> usize {
        self.bound + 1
    }

    /// Exclusive upper bound of every index component: `bound + 1`, repeated
    /// `d` times.
    pub fn index_bounds(&self) -> Vec<usize> {
        let per_dim = self.bound + 1;
        vec![per_dim; self.d]
    }

    /// Total number of multi-indices in the set, `C(bound + d, d)`.
    pub fn num_values(&self) -> usize {
        let dims = self.d;
        binom(self.bound + dims, dims)
    }

    /// For each first-index value `0..=bound`, the number of multi-indices
    /// that start with it.
    pub fn num_values_per_first_index(&self) -> Vec<usize> {
        let rest = self.d - 1;
        // Walking the remaining budget from `bound` down to 0 matches first
        // index values 0, 1, ..., bound.
        (0..=self.bound)
            .rev()
            .map(|remaining| binom(remaining + rest, rest))
            .collect()
    }

    /// Zeroes the head and marks the iterator as exhausted.
    pub fn go_to_end(&mut self) {
        self.reset();
        self.valid = false;
    }

    /// Cycle: last dimension moves to front. For a bounded-sum set
    /// the constraint is symmetric, so the result is simply a fresh
    /// iterator with the same `d` and `bound`, positioned at the start.
    pub fn cycle(&self) -> Self {
        Self {
            d: self.d,
            bound: self.bound,
            head: vec![0; self.head.len()],
            head_sum: 0,
            valid: true,
        }
    }
}
/// Computes the binomial coefficient `C(n, k)`: the number of ways to choose
/// `k` items out of `n`.
///
/// Returns 0 when `k > n`, since no such selection exists, and 1 when `k == 0`
/// or `k == n`.
///
/// # Panics
///
/// Panics if the result does not fit in `usize`.
pub fn binom(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }
    // C(n, k) == C(n, n - k); iterate over the smaller of the two.
    let k = k.min(n - k);
    let mut prod: usize = 1;
    for i in 0..k {
        // `prod` equals C(n, i) here, so `prod * (n - i)` is an exact multiple
        // of `i + 1`. The product of two `usize` values always fits in `u128`,
        // which keeps the intermediate from overflowing when the quotient
        // itself is representable.
        let wide = prod as u128 * (n - i) as u128 / (i as u128 + 1);
        assert!(
            wide <= usize::MAX as u128,
            "binomial coefficient does not fit in usize"
        );
        // Lossless: range-checked by the assertion above.
        prod = wide as usize;
    }
    prod
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `it` to exhaustion and returns the sum of the last-dimension
    /// counts over every visited position.
    fn drain_total(mut it: BoundedSumIter) -> usize {
        let mut acc = 0;
        while it.valid() {
            acc += it.last_dim_count();
            it.next();
        }
        acc
    }

    #[test]
    fn binom_basic() {
        // (n, k, expected C(n, k))
        let table = [(5, 2, 10), (10, 3, 120), (0, 0, 1)];
        for &(n, k, expected) in &table {
            assert_eq!(binom(n, k), expected);
        }
    }

    #[test]
    fn iter_count_2d() {
        // d = 2, bound = 3: C(5,2) = 10 multi-indices
        let count = BoundedSumIter::new(2, 3).num_values();
        assert_eq!(count, binom(5, 2));
    }

    #[test]
    fn iter_traversal_2d() {
        assert_eq!(drain_total(BoundedSumIter::new(2, 3)), 10);
    }

    #[test]
    fn iter_traversal_3d() {
        // {(i,j,k) | i+j+k <= 2} has C(5,3) = 10 elements
        assert_eq!(drain_total(BoundedSumIter::new(3, 2)), 10);
    }

    #[test]
    fn num_values_per_first() {
        // first_index=0: C(4,2)=6, first_index=1: C(3,2)=3, first_index=2: C(2,2)=1
        let per_first = BoundedSumIter::new(3, 2).num_values_per_first_index();
        assert_eq!(per_first, vec![6, 3, 1]);
    }
}