diff --git a/Cargo.toml b/Cargo.toml index 2e370ce..dc00021 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,3 +6,4 @@ description = "Multisets/bags" keywords = ["multiset","bag","data-structure","collection","count"] license = "MIT/Apache-2.0" authors = ["Jake Mitchell "] +edition = "2015" diff --git a/src/multiset.rs b/src/multiset.rs index b93045d..d0dac07 100644 --- a/src/multiset.rs +++ b/src/multiset.rs @@ -14,6 +14,7 @@ use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::iter::{FromIterator, IntoIterator}; +use std::num::NonZero; use std::ops::{Add, Sub}; /// A hash-based multiset. @@ -22,7 +23,7 @@ pub struct HashMultiSet where K: Eq + Hash, { - elem_counts: HashMap, + elem_counts: HashMap>, size: usize, } @@ -31,28 +32,19 @@ where /// This `struct` is created by the [`iter`] method on [`HashMultiSet`]. #[derive(Clone)] pub struct Iter<'a, K: 'a> { - iter: hash_map::Iter<'a, K, usize>, - duplicate: Option<(&'a K, &'a usize)>, - duplicate_index: usize, + iter: hash_map::Iter<'a, K, NonZero>, + current: Option<(&'a K, NonZero)>, } impl<'a, K> Iterator for Iter<'a, K> { type Item = &'a K; fn next(&mut self) -> Option<&'a K> { - if self.duplicate.is_none() { - self.duplicate = self.iter.next(); - } - if let Some((key, count)) = self.duplicate { - self.duplicate_index += 1; - if self.duplicate_index >= *count { - self.duplicate = None; - self.duplicate_index = 0; - } - Some(key) - } else { - None + let (key, count) = self.current.take().or_else(|| self.iter.next().map(|(key, &count)| (key, count)))?; + if let Some(new_count) = NonZero::new(count.get() - 1) { + self.current = Some((key, new_count)); } + Some(key) } } @@ -94,11 +86,10 @@ where /// } /// assert_eq!(3, multiset.iter().count()); /// ``` - pub fn iter(&self) -> Iter { + pub fn iter(&self) -> Iter<'_, K> { Iter { iter: self.elem_counts.iter(), - duplicate: None, - duplicate_index: 0, + current: None, } } @@ -133,10 +124,10 @@ where /// assert_eq!(set.contains(&1), true); /// assert_eq!(set.contains(&4), false); /// ``` - pub fn contains(&self, value: &Q) -> bool + pub fn contains(&self, value: &Q) -> bool where K: Borrow, - Q: Hash + Eq, + Q: Hash + Eq + ?Sized, { self.elem_counts.contains_key(value) } @@ -186,7 +177,7 @@ where /// assert!(distinct.contains(&2)); /// assert!(!distinct.contains(&3)); /// ``` - pub fn distinct_elements(&self) -> Keys { + pub fn distinct_elements(&self) -> Keys<'_, K, NonZero> { self.elem_counts.keys() } @@ -223,14 +214,15 @@ where /// assert_eq!(3, multiset.count_of(&5)); /// ``` pub fn insert_times(&mut self, val: K, n: usize) { - self.size += n; + self.size = self.size.checked_add(n).expect("count overflow"); + let Some(n) = NonZero::new(n) else { return; }; match self.elem_counts.entry(val) { Entry::Vacant(view) => { view.insert(n); } Entry::Occupied(mut view) => { let v = view.get_mut(); - *v += n; + *v = v.checked_add(n.get()).expect("count overflow"); } } } @@ -280,15 +272,18 @@ where /// assert!(multiset.count_of(&5) == 0); /// ``` pub fn remove_times(&mut self, val: &K, times: usize) -> usize { - if let Some(count) = self.elem_counts.get_mut(val) { - if *count > times { - *count -= times; - self.size -= times; - return times; - } - self.size -= *count; + let Some(actual_count) = self.elem_counts.get_mut(val) else { return 0; }; + let result = if let Some(new_count) = actual_count.get().checked_sub(times).and_then(NonZero::new) { + *actual_count = new_count; + times } - self.elem_counts.remove(val).unwrap_or(0) + else { + let result = actual_count.get(); + self.elem_counts.remove(val); + result + }; + self.size -= result; + result } /// Remove all of an element from the multiset. @@ -308,8 +303,8 @@ where /// assert!(multiset.len() == 0); /// ``` pub fn remove_all(&mut self, val: &K) { - self.size -= self.elem_counts.get(val).unwrap_or(&0); - self.elem_counts.remove(val); + let count = self.elem_counts.remove(val).map(NonZero::get).unwrap_or_default(); + self.size -= count; } /// Counts the occurrences of `val`. @@ -328,7 +323,7 @@ where /// assert_eq!(1, multiset.count_of(&1)); /// ``` pub fn count_of(&self, val: &K) -> usize { - self.elem_counts.get(val).map_or(0, |x| *x) + self.elem_counts.get(val).map(|c| c.get()).unwrap_or_default() } } @@ -358,7 +353,7 @@ where /// ``` fn add(mut self, rhs: HashMultiSet) -> HashMultiSet { for (val, count) in rhs.elem_counts { - self.insert_times(val, count); + self.insert_times(val, count.get()); } self } @@ -391,7 +386,7 @@ where /// ``` fn sub(mut self, rhs: HashMultiSet) -> HashMultiSet { for (val, count) in rhs.elem_counts { - self.remove_times(&val, count); + self.remove_times(&val, count.get()); } self } @@ -537,4 +532,11 @@ mod test_multiset { set.remove(&'d'); assert_eq!(set.len(), 0); } + + #[test] + fn test_insert_times_zero() { + let mut a = HashMultiSet::new(); + a.insert_times(5, 0); + assert!(!a.contains(&5)); + } }