Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1050,10 +1050,16 @@ impl<
Ok(())
}

/// Upserts a placeholder, optionally validating existing resident values
///
/// Returns:
/// - `Ok((token, value))` if a valid resident was found
/// - `Err((placeholder, is_new))` where `is_new` indicates if this is a newly created placeholder
pub fn upsert_placeholder<Q>(
&mut self,
hash: u64,
key: &Q,
validator: &mut impl FnMut(&Val) -> bool,
) -> Result<(Token, &Val), (Plh, bool)>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
Expand All @@ -1063,6 +1069,40 @@ impl<
let (entry, _) = self.entries.get_mut(idx).unwrap();
match entry {
Entry::Resident(resident) => {
if !validator(&resident.value) {
let old_state = resident.state;
let old_key = &resident.key;
let old_value = &resident.value;
let weight = self.weighter.weight(old_key, old_value);

let shared = Plh::new(hash, idx);
*entry = Entry::Placeholder(Placeholder {
key: key.to_owned(),
hot: old_state,
shared: shared.clone(),
});

match old_state {
ResidentState::Hot => {
self.num_hot -= 1;
self.weight_hot -= weight;
if weight != 0 {
self.hot_head = self.entries.unlink(idx);
}
}
ResidentState::Cold => {
self.num_cold -= 1;
self.weight_cold -= weight;
if weight != 0 {
self.cold_head = self.entries.unlink(idx);
}
}
}

record_miss_mut!(self);
return Err((shared, false)); // false = replaced existing
}

if *resident.referenced.get_mut() < MAX_F {
*resident.referenced.get_mut() += 1;
}
Expand Down
35 changes: 32 additions & 3 deletions src/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,14 +435,43 @@ impl<
&'a self,
key: &Q,
) -> Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
self.get_validate_value_or_guard_async(key, |_| true).await
}

/// Gets an item from the cache with key `key`, applying `validation` to determine whether
/// the value is 'live'.
///
/// If the corresponding value isn't present in the cache or fails validation, this functions
/// returns a guard that can be used to insert the value once it's computed.
/// While the returned guard is alive, other calls with the same key using the
/// `get_value_guard` or `get_or_insert` family of functions will wait until the guard
/// is dropped or the value is inserted.
pub async fn get_validate_value_or_guard_async<'a, Q>(
&'a self,
key: &Q,
mut validation: impl FnMut(&Val) -> bool + Unpin,
) -> Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
let (shard, hash) = self.shard_for(key).unwrap();
if let Some(v) = shard.read().get(hash, key) {
return Ok(v.clone());

// Try fast path with read lock first
{
let reader = shard.read();
if let Some(v) = reader.get(hash, key) {
if validation(v) {
return Ok(v.clone());
}
// Validation failed, fall through to JoinFuture
}
// No entry found or validation failed, let JoinFuture handle everything
}
JoinFuture::new(&self.lifecycle, shard, hash, key).await

JoinFuture::new(&self.lifecycle, shard, hash, key, validation).await
}

/// Gets or inserts an item in the cache with key `key`.
Expand Down
23 changes: 16 additions & 7 deletions src/sync_placeholder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ impl<
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
let mut shard_guard = shard.write();
let shared = match shard_guard.upsert_placeholder(hash, key) {
let shared = match shard_guard.upsert_placeholder(hash, key, &mut |_| true) {
Ok((_, v)) => return GuardResult::Value(v.clone()),
Err((shared, true)) => {
return GuardResult::Guard(Self::start_loading(lifecycle, shard, shared));
Expand Down Expand Up @@ -413,11 +413,12 @@ impl<Key, Val, We, B, L> std::fmt::Debug for PlaceholderGuard<'_, Key, Val, We,
}

/// Future that results in an Ok(Value) or Err(Guard)
pub struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L> {
pub struct JoinFuture<'a, 'b, Q: ?Sized, Key, Val, We, B, L, F: FnMut(&'a Val) -> bool> {
lifecycle: &'a L,
shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
state: JoinFutureState<'b, Q, Val>,
notified: AtomicBool,
validation: F,
}

enum JoinFutureState<'b, Q: ?Sized, Val> {
Expand All @@ -432,18 +433,22 @@ enum JoinFutureState<'b, Q: ?Sized, Val> {
Done,
}

impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L, F: FnMut(&'a Val) -> bool>
JoinFuture<'a, 'b, Q, Key, Val, We, B, L, F>
{
pub fn new(
lifecycle: &'a L,
shard: &'a RwLock<CacheShard<Key, Val, We, B, L, SharedPlaceholder<Val>>>,
hash: u64,
key: &'b Q,
) -> JoinFuture<'a, 'b, Q, Key, Val, We, B, L> {
validation: F,
) -> JoinFuture<'a, 'b, Q, Key, Val, We, B, L, F> {
Self {
lifecycle,
shard,
state: JoinFutureState::Created { hash, key },
notified: Default::default(),
validation,
}
}

Expand Down Expand Up @@ -480,7 +485,10 @@ impl<'a, 'b, Q: ?Sized, Key, Val, We, B, L> JoinFuture<'a, 'b, Q, Key, Val, We,
}
}

impl<Q: ?Sized, Key, Val, We, B, L> Drop for JoinFuture<'_, '_, Q, Key, Val, We, B, L> {
impl<'a, Q: ?Sized, Key, Val, We, B, L, F> Drop for JoinFuture<'a, '_, Q, Key, Val, We, B, L, F>
where
F: FnMut(&'a Val) -> bool,
{
#[inline]
fn drop(&mut self) {
if matches!(self.state, JoinFutureState::Pending { .. }) {
Expand All @@ -497,7 +505,8 @@ impl<
We: Weighter<Key, Val>,
B: BuildHasher,
L: Lifecycle<Key, Val>,
> Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L>
F: FnMut(&Val) -> bool + Unpin,
> Future for JoinFuture<'a, '_, Q, Key, Val, We, B, L, F>
{
type Output = Result<Val, PlaceholderGuard<'a, Key, Val, We, B, L>>;

Expand All @@ -509,7 +518,7 @@ impl<
JoinFutureState::Created { hash, key } => {
debug_assert!(!this.notified.load(Ordering::Acquire));
let mut shard_guard = shard.write();
match shard_guard.upsert_placeholder(*hash, *key) {
match shard_guard.upsert_placeholder(*hash, *key, &mut this.validation) {
Ok((_, v)) => {
this.state = JoinFutureState::Done;
Poll::Ready(Ok(v.clone()))
Expand Down
20 changes: 16 additions & 4 deletions src/unsync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>, B: BuildHasher, L: Lifecycle<K
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
let idx = match self.shard.upsert_placeholder(self.shard.hash(key), key) {
let idx = match self
.shard
.upsert_placeholder(self.shard.hash(key), key, &mut |_| true)
{
Ok((idx, _)) => idx,
Err((plh, _)) => {
let v = with()?;
Expand All @@ -275,7 +278,10 @@ impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>, B: BuildHasher, L: Lifecycle<K
where
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
let idx = match self.shard.upsert_placeholder(self.shard.hash(key), key) {
let idx = match self
.shard
.upsert_placeholder(self.shard.hash(key), key, &mut |_| true)
{
Ok((idx, _)) => idx,
Err((plh, _)) => {
let v = with()?;
Expand All @@ -297,7 +303,10 @@ impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>, B: BuildHasher, L: Lifecycle<K
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
// TODO: this could be using a simpler entry API
match self.shard.upsert_placeholder(self.shard.hash(key), key) {
match self
.shard
.upsert_placeholder(self.shard.hash(key), key, &mut |_| true)
{
Ok((_, v)) => unsafe {
// Rustc gets insanely confused about returning from mut borrows
// Safety: v has the same lifetime as self
Expand All @@ -323,7 +332,10 @@ impl<Key: Eq + Hash, Val, We: Weighter<Key, Val>, B: BuildHasher, L: Lifecycle<K
Q: Hash + Equivalent<Key> + ToOwned<Owned = Key> + ?Sized,
{
// TODO: this could be using a simpler entry API
match self.shard.upsert_placeholder(self.shard.hash(key), key) {
match self
.shard
.upsert_placeholder(self.shard.hash(key), key, &mut |_| true)
{
Ok((idx, _)) => Ok(self.shard.peek_token_mut(idx).map(RefMut)),
Err((placeholder, _)) => Err(Guard {
cache: self,
Expand Down