Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ description = "Multisets/bags"
keywords = ["multiset","bag","data-structure","collection","count"]
license = "MIT/Apache-2.0"
authors = ["Jake Mitchell <jacob.d.mitchell@gmail.com>"]
edition = "2015"
76 changes: 39 additions & 37 deletions src/multiset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::iter::{FromIterator, IntoIterator};
use std::num::NonZero;
use std::ops::{Add, Sub};

/// A hash-based multiset.
Expand All @@ -22,7 +23,7 @@ pub struct HashMultiSet<K>
where
K: Eq + Hash,
{
elem_counts: HashMap<K, usize>,
elem_counts: HashMap<K, NonZero<usize>>,
size: usize,
}

Expand All @@ -31,28 +32,19 @@ where
/// This `struct` is created by the [`iter`] method on [`HashMultiSet`].
#[derive(Clone)]
pub struct Iter<'a, K: 'a> {
iter: hash_map::Iter<'a, K, usize>,
duplicate: Option<(&'a K, &'a usize)>,
duplicate_index: usize,
iter: hash_map::Iter<'a, K, NonZero<usize>>,
current: Option<(&'a K, NonZero<usize>)>,
}

impl<'a, K> Iterator for Iter<'a, K> {
type Item = &'a K;

fn next(&mut self) -> Option<&'a K> {
if self.duplicate.is_none() {
self.duplicate = self.iter.next();
}
if let Some((key, count)) = self.duplicate {
self.duplicate_index += 1;
if self.duplicate_index >= *count {
self.duplicate = None;
self.duplicate_index = 0;
}
Some(key)
} else {
None
let (key, count) = self.current.take().or_else(|| self.iter.next().map(|(key, &count)| (key, count)))?;
if let Some(new_count) = NonZero::new(count.get() - 1) {
self.current = Some((key, new_count));
}
Some(key)
}
}

Expand Down Expand Up @@ -94,11 +86,10 @@ where
/// }
/// assert_eq!(3, multiset.iter().count());
/// ```
pub fn iter(&self) -> Iter<K> {
pub fn iter(&self) -> Iter<'_, K> {
Iter {
iter: self.elem_counts.iter(),
duplicate: None,
duplicate_index: 0,
current: None,
}
}

Expand Down Expand Up @@ -133,10 +124,10 @@ where
/// assert_eq!(set.contains(&1), true);
/// assert_eq!(set.contains(&4), false);
/// ```
pub fn contains<Q: ?Sized>(&self, value: &Q) -> bool
pub fn contains<Q>(&self, value: &Q) -> bool
where
K: Borrow<Q>,
Q: Hash + Eq,
Q: Hash + Eq + ?Sized,
{
self.elem_counts.contains_key(value)
}
Expand Down Expand Up @@ -186,7 +177,7 @@ where
/// assert!(distinct.contains(&2));
/// assert!(!distinct.contains(&3));
/// ```
pub fn distinct_elements(&self) -> Keys<K, usize> {
pub fn distinct_elements(&self) -> Keys<'_, K, NonZero<usize>> {
self.elem_counts.keys()
}

Expand Down Expand Up @@ -223,14 +214,15 @@ where
/// assert_eq!(3, multiset.count_of(&5));
/// ```
pub fn insert_times(&mut self, val: K, n: usize) {
self.size += n;
self.size = self.size.checked_add(n).expect("count overflow");
let Some(n) = NonZero::new(n) else { return; };
match self.elem_counts.entry(val) {
Entry::Vacant(view) => {
view.insert(n);
}
Entry::Occupied(mut view) => {
let v = view.get_mut();
*v += n;
*v = v.checked_add(n.get()).expect("count overflow");
}
}
}
Expand Down Expand Up @@ -280,15 +272,18 @@ where
/// assert!(multiset.count_of(&5) == 0);
/// ```
pub fn remove_times(&mut self, val: &K, times: usize) -> usize {
if let Some(count) = self.elem_counts.get_mut(val) {
if *count > times {
*count -= times;
self.size -= times;
return times;
}
self.size -= *count;
let Some(actual_count) = self.elem_counts.get_mut(val) else { return 0; };
let result = if let Some(new_count) = actual_count.get().checked_sub(times).and_then(NonZero::new) {
*actual_count = new_count;
times
}
self.elem_counts.remove(val).unwrap_or(0)
else {
let result = actual_count.get();
self.elem_counts.remove(val);
result
};
self.size -= result;
result
}

/// Remove all of an element from the multiset.
Expand All @@ -308,8 +303,8 @@ where
/// assert!(multiset.len() == 0);
/// ```
pub fn remove_all(&mut self, val: &K) {
self.size -= self.elem_counts.get(val).unwrap_or(&0);
self.elem_counts.remove(val);
let count = self.elem_counts.remove(val).map(NonZero::get).unwrap_or_default();
self.size -= count;
}

/// Counts the occurrences of `val`.
Expand All @@ -328,7 +323,7 @@ where
/// assert_eq!(1, multiset.count_of(&1));
/// ```
pub fn count_of(&self, val: &K) -> usize {
self.elem_counts.get(val).map_or(0, |x| *x)
self.elem_counts.get(val).map(|c| c.get()).unwrap_or_default()
}
}

Expand Down Expand Up @@ -358,7 +353,7 @@ where
/// ```
fn add(mut self, rhs: HashMultiSet<T>) -> HashMultiSet<T> {
for (val, count) in rhs.elem_counts {
self.insert_times(val, count);
self.insert_times(val, count.get());
}
self
}
Expand Down Expand Up @@ -391,7 +386,7 @@ where
/// ```
fn sub(mut self, rhs: HashMultiSet<T>) -> HashMultiSet<T> {
for (val, count) in rhs.elem_counts {
self.remove_times(&val, count);
self.remove_times(&val, count.get());
}
self
}
Expand Down Expand Up @@ -537,4 +532,11 @@ mod test_multiset {
set.remove(&'d');
assert_eq!(set.len(), 0);
}

#[test]
fn test_insert_times_zero() {
let mut a = HashMultiSet::new();
a.insert_times(5, 0);
assert!(!a.contains(&5));
}
}