diff --git a/roaring/src/bitmap/iter.rs b/roaring/src/bitmap/iter.rs index 1593ab40..4076950d 100644 --- a/roaring/src/bitmap/iter.rs +++ b/roaring/src/bitmap/iter.rs @@ -96,6 +96,43 @@ impl Iterator for Iter<'_> { }; init } + + fn count(self) -> usize + where + Self: Sized, + { + let mut count = self.front.map_or(0, Iterator::count); + count += self.containers.map(|container| container.len() as usize).sum::(); + count += self.back.map_or(0, Iterator::count); + count + } + + fn nth(&mut self, n: usize) -> Option { + let mut n = n; + let nth_advance = |it: &mut container::Iter| { + let len = it.len(); + if n < len { + it.nth(n) + } else { + n -= len; + None + } + }; + if let Some(x) = and_then_or_clear(&mut self.front, nth_advance) { + return Some(x); + } + for container in self.containers.by_ref() { + let len = container.len() as usize; + if n < len { + let mut front_iter = container.into_iter(); + let result = front_iter.nth(n); + self.front = Some(front_iter); + return result; + } + n -= len; + } + self.back.as_mut().and_then(|it| it.nth(n)) + } } impl DoubleEndedIterator for Iter<'_> { @@ -128,6 +165,33 @@ impl DoubleEndedIterator for Iter<'_> { }; init } + + fn nth_back(&mut self, n: usize) -> Option { + let mut n = n; + let nth_advance = |it: &mut container::Iter| { + let len = it.len(); + if n < len { + it.nth_back(n) + } else { + n -= len; + None + } + }; + if let Some(x) = and_then_or_clear(&mut self.back, nth_advance) { + return Some(x); + } + for container in self.containers.by_ref().rev() { + let len = container.len() as usize; + if n < len { + let mut front_iter = container.into_iter(); + let result = front_iter.nth_back(n); + self.back = Some(front_iter); + return result; + } + n -= len; + } + self.front.as_mut().and_then(|it| it.nth_back(n)) + } } #[cfg(target_pointer_width = "64")] @@ -170,6 +234,43 @@ impl Iterator for IntoIter { }; init } + + fn count(self) -> usize + where + Self: Sized, + { + let mut count = self.front.map_or(0, Iterator::count); + count += self.containers.map(|container| container.len() as usize).sum::(); + count += self.back.map_or(0, Iterator::count); + count + } + + fn nth(&mut self, n: usize) -> Option { + let mut n = n; + let nth_advance = |it: &mut container::Iter| { + let len = it.len(); + if n < len { + it.nth(n) + } else { + n -= len; + None + } + }; + if let Some(x) = and_then_or_clear(&mut self.front, nth_advance) { + return Some(x); + } + for container in self.containers.by_ref() { + let len = container.len() as usize; + if n < len { + let mut front_iter = container.into_iter(); + let result = front_iter.nth(n); + self.front = Some(front_iter); + return result; + } + n -= len; + } + self.back.as_mut().and_then(|it| it.nth(n)) + } } impl DoubleEndedIterator for IntoIter { @@ -202,6 +303,33 @@ impl DoubleEndedIterator for IntoIter { }; init } + + fn nth_back(&mut self, n: usize) -> Option { + let mut n = n; + let nth_advance = |it: &mut container::Iter| { + let len = it.len(); + if n < len { + it.nth_back(n) + } else { + n -= len; + None + } + }; + if let Some(x) = and_then_or_clear(&mut self.back, nth_advance) { + return Some(x); + } + for container in self.containers.by_ref().rev() { + let len = container.len() as usize; + if n < len { + let mut front_iter = container.into_iter(); + let result = front_iter.nth_back(n); + self.back = Some(front_iter); + return result; + } + n -= len; + } + self.front.as_mut().and_then(|it| it.nth_back(n)) + } } #[cfg(target_pointer_width = "64")] diff --git a/roaring/tests/iter.rs b/roaring/tests/iter.rs index 86a83245..05591681 100644 --- a/roaring/tests/iter.rs +++ b/roaring/tests/iter.rs @@ -81,6 +81,53 @@ proptest! { } } +proptest! { + #[test] + fn nth(values in btree_set(any::(), ..=10_000), nth in 0..10_005usize) { + let bitmap = RoaringBitmap::from_sorted_iter(values.iter().cloned()).unwrap(); + let mut orig_iter = bitmap.iter().fuse(); + let mut iter = bitmap.iter(); + + for _ in 0..nth { + if orig_iter.next().is_none() { + break; + } + } + let expected = orig_iter.next(); + assert_eq!(expected, iter.nth(nth)); + let expected_next = orig_iter.next(); + assert_eq!(expected_next, iter.next()); + + let mut val_iter = values.into_iter(); + assert_eq!(expected, val_iter.nth(nth)); + assert_eq!(expected_next, val_iter.next()); + } +} + +#[test] +fn huge_nth() { + let bitmap = RoaringBitmap::new(); + let mut iter = bitmap.iter(); + assert_eq!(None, iter.nth(usize::MAX)); +} + +proptest! { + #[test] + fn count(values in btree_set(any::(), ..=10_000), skip in 0..10_005usize) { + let bitmap = RoaringBitmap::from_sorted_iter(values.iter().cloned()).unwrap(); + let mut iter = bitmap.iter(); + + if let Some(n) = skip.checked_sub(1) { + iter.nth(n); + } + let expected_count = values.len().saturating_sub(skip); + let size_hint = iter.size_hint(); + assert_eq!(expected_count, size_hint.0); + assert_eq!(Some(expected_count), size_hint.1); + assert_eq!(expected_count, iter.count()); + } +} + #[test] fn rev_array() { let values = 0..100;