Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/iddqd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@
#![no_std]
#![cfg_attr(doc_cfg, feature(doc_auto_cfg))]
#![warn(missing_docs)]
#![feature(btree_set_entry)]

#[cfg_attr(not(feature = "std"), macro_use)] // for `format!`
extern crate alloc;
Expand Down
192 changes: 84 additions & 108 deletions crates/iddqd/src/support/btree_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,60 +11,12 @@ use alloc::{
vec::Vec,
};
use core::{
cell::Cell,
borrow::Borrow,
cmp::Ordering,
hash::{BuildHasher, Hash},
marker::PhantomData,
};
use equivalent::Comparable;

thread_local! {
/// Stores an external comparator function to provide dynamic scoping.
///
/// std's BTreeMap doesn't allow passing an external comparator, so we make
/// do with this function that's passed in through dynamic scoping.
///
/// This works by:
///
/// * We store an `Index` in the BTreeSet which knows how to call this
/// dynamic comparator.
/// * When we need to compare two `Index` values, we create a CmpDropGuard.
/// This struct is responsible for managing the lifetime of the
/// comparator.
/// * When the CmpDropGuard is dropped (including due to a panic), we reset
/// the comparator to None.
///
/// This is not great! (For one, thread-locals and no-std don't really mix.)
/// Some alternatives:
///
/// * Using `Borrow` as described in
/// https://github.com/sunshowers-code/borrow-complex-key-example. While
/// hacky, this actually works for the find operation. But the insert
/// operation currently requires a concrete `Index`.
///
/// If and when https://github.com/rust-lang/rust/issues/133549 lands,
/// this should become a viable option. Worth looking out for!
///
/// * Using a third-party BTreeSet implementation that allows passing in
/// external comparators. As of 2025-05, there appear to be two options:
///
/// 1. copse (https://docs.rs/copse), which doesn't seem like a good fit
/// here.
/// 2. btree_monstrousity (https://crates.io/crates/btree_monstrousity),
/// which has an API perfect for this but is, uhh, not really
/// production-ready.
///
/// Third-party implementations also run the risk of being relatively
/// untested.
///
/// * Using some other kind of sorted set. We've picked B-trees here as the
/// default choice to balance cache locality, but other options are worth
/// benchmarking. We do need to provide a comparator, though, so radix
/// trees and such are out of the question.
static CMP: Cell<Option<&'static dyn Fn(Index, Index) -> Ordering>>
= const { Cell::new(None) };
}

/// A B-tree-based table with an external comparator.
#[derive(Clone, Debug, Default)]
pub(crate) struct MapBTreeTable {
Expand Down Expand Up @@ -154,10 +106,10 @@ impl MapBTreeTable {
F: Fn(usize) -> K,
{
let f = find_cmp(key, lookup);
let cmp_wrapper =
CmpWrapper { index: Index::SENTINEL, cmp_fn: Some(&f) };

let guard = CmpDropGuard::new(&f);

let ret = match self.items.get(&Index::SENTINEL) {
let ret = match self.items.get(&cmp_wrapper as &dyn CmpKey<_>) {
Some(Index(v)) if *v == Index::SENTINEL_VALUE => {
panic!("internal map shouldn't store sentinel value")
}
Expand All @@ -168,8 +120,6 @@ impl MapBTreeTable {
}
};

// drop(guard) isn't necessary, but we make it explicit
drop(guard);
ret
}

Expand All @@ -180,12 +130,11 @@ impl MapBTreeTable {
F: Fn(usize) -> K,
{
let f = insert_cmp(index, key, lookup);
let guard = CmpDropGuard::new(&f);

self.items.insert(Index::new(index));
let index = Index::new(index);
let cmp_wrapper = CmpWrapper { index, cmp_fn: Some(&f) };

// drop(guard) isn't necessary, but we make it explicit
drop(guard);
self.items
.get_or_insert_with(&cmp_wrapper as &dyn CmpKey<_>, |_| index);
}

pub(crate) fn remove<K, F>(&mut self, index: usize, key: K, lookup: F)
Expand All @@ -194,12 +143,10 @@ impl MapBTreeTable {
K: Ord,
{
let f = insert_cmp(index, &key, lookup);
let guard = CmpDropGuard::new(&f);
let find_cmp =
CmpWrapper { index: Index::new(index), cmp_fn: Some(&f) };

self.items.remove(&Index::new(index));

// drop(guard) isn't necessary, but we make it explicit
drop(guard);
self.items.remove(&find_cmp as &dyn CmpKey<_>);
}

pub(crate) fn iter(&self) -> Iter {
Expand Down Expand Up @@ -317,37 +264,6 @@ where
}
}

struct CmpDropGuard<'a> {
_marker: PhantomData<&'a ()>,
}

impl<'a> CmpDropGuard<'a> {
fn new(f: &'a dyn Fn(Index, Index) -> Ordering) -> Self {
// CMP lasts only as long as this function and is immediately reset to
// None once this scope is left.
let ret = Self { _marker: PhantomData };

let as_static = unsafe {
// SAFETY: This is safe because we are not storing the reference
// anywhere, and it is only used for the lifetime of this
// CmpDropGuard.
std::mem::transmute::<
&'a dyn Fn(Index, Index) -> Ordering,
&'static dyn Fn(Index, Index) -> Ordering,
>(f)
};
CMP.set(Some(as_static));

ret
}
}

impl Drop for CmpDropGuard<'_> {
fn drop(&mut self) {
CMP.set(None);
}
}

/// A `usize` wrapped in a newtype; this is the element type the table stores
/// and compares.
///
/// `Index::SENTINEL` is passed in place of a real index when a lookup is
/// driven by an external key (see the `CmpWrapper` built in the find path).
// NOTE(review): presumably the wrapped value is a position into an external
// item store resolved through the `lookup` callbacks — confirm against callers.
#[derive(Clone, Copy, Debug)]
struct Index(usize);

Expand All @@ -373,26 +289,16 @@ impl PartialEq for Index {
return self.0 == other.0;
}

// If any of the two indexes is the sentinel, we're required to perform
// a lookup.
CMP.with(|cmp| {
let cmp = cmp.get().expect("cmp should be set");
cmp(*self, *other) == Ordering::Equal
})
panic!("we should never call PartialEq on indexes");
}
}

// Marker impl with no logic of its own: `Ord` has `Eq` as a supertrait, so it
// must exist for `Index` to implement `Ord`.
impl Eq for Index {}

impl Ord for Index {
#[inline]
fn cmp(&self, other: &Self) -> Ordering {
// Ord should only be called if we're doing lookups within the table,
// which should have set the thread local.
CMP.with(|cmp| {
let cmp = cmp.get().expect("cmp should be set");
cmp(*self, *other)
})
fn cmp(&self, _other: &Self) -> Ordering {
panic!("we should never call Ord on indexes");
}
}

Expand All @@ -402,3 +308,73 @@ impl PartialOrd for Index {
Some(self.cmp(other))
}
}

/// A comparison key: an [`Index`] paired with an optional borrowed comparator.
///
/// Table operations build one of these with `cmp_fn: Some(..)`. Keys derived
/// from a bare stored `Index` (via its `CmpKey` impl) carry `cmp_fn: None`, so
/// a comparison between two keys relies on at least one side supplying the
/// comparator.
struct CmpWrapper<'a, F> {
/// The index this key stands for.
index: Index,
/// The comparator to use, if this key carries one.
cmp_fn: Option<&'a F>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `F: Clone` bound, but only a shared reference to `F` is stored, so the
// wrapper is cloneable for every `F`.
impl<F> Clone for CmpWrapper<'_, F> {
    /// Returns a bitwise copy of this wrapper.
    fn clone(&self) -> Self {
        // `CmpWrapper` is `Copy`, so the canonical `Clone` is a plain
        // dereference; this keeps `clone()` and copy semantics identical.
        *self
    }
}

// Both fields (`Index` and `Option<&F>`) are `Copy` regardless of `F`, so the
// wrapper can be copied freely without placing any bound on `F`.
impl<F> Copy for CmpWrapper<'_, F> {}

/// Common view over anything that can take part in a key comparison.
///
/// It is used as the trait object `dyn CmpKey<F>`, which lets a stored
/// `Index` and a `CmpWrapper` lookup key be compared against each other.
trait CmpKey<F> {
/// Returns this key as a `CmpWrapper` (index plus optional comparator).
fn key(&self) -> CmpWrapper<'_, F>;
}

/// A bare `Index` knows nothing about how to compare itself, so its key view
/// carries no comparator.
impl<F> CmpKey<F> for Index {
    fn key(&self) -> CmpWrapper<'_, F> {
        let index = *self;
        // The other side of the comparison is expected to supply `cmp_fn`.
        CmpWrapper { index, cmp_fn: None }
    }
}

/// A `CmpWrapper` is already in key form, so its key view is itself.
// The wrapper's lifetime is elided here, matching the `Clone` and `Copy`
// impls for the same type.
impl<F> CmpKey<F> for CmpWrapper<'_, F> {
    fn key(&self) -> CmpWrapper<'_, F> {
        // `CmpWrapper` is `Copy`; hand back a copy of ourselves.
        *self
    }
}

/// Lets an `Index` be borrowed as a `dyn CmpKey<F>` trait object.
///
/// This is what allows the table's find/insert/remove paths to pass
/// `&cmp_wrapper as &dyn CmpKey<_>` to a collection whose stored element type
/// is `Index`.
impl<'a, F> Borrow<dyn CmpKey<F> + 'a> for Index {
fn borrow(&self) -> &(dyn CmpKey<F> + 'a) {
// Unsized coercion from `&Index` to `&dyn CmpKey<F>`.
self
}
}

/// Equality between two key trait objects, defined as "the comparator says
/// `Ordering::Equal`".
///
/// # Panics
///
/// Panics with "at least one key must be set" if neither side carries a
/// comparator (the panic comes from the `Ord` impl this delegates to).
impl<'a, F: Fn(Index, Index) -> Ordering> PartialEq for (dyn CmpKey<F> + 'a) {
    fn eq(&self, other: &Self) -> bool {
        // Delegate to `Ord::cmp` instead of re-implementing comparator
        // selection here, so equality and ordering can never disagree.
        Ord::cmp(self, other) == Ordering::Equal
    }
}

// Marker impl with no logic of its own: `Ord` has `Eq` as a supertrait, so it
// must exist for `dyn CmpKey<F>` to implement `Ord`.
impl<'a, F: Fn(Index, Index) -> Ordering> Eq for (dyn CmpKey<F> + 'a) {}

/// Total order between two key trait objects, computed by whichever side
/// carries a comparator.
///
/// The left-hand key's comparator is preferred; the right-hand key's is used
/// only when the left has none. The indexes are always passed to the
/// comparator in `(self, other)` order.
///
/// # Panics
///
/// Panics with "at least one key must be set" if neither side carries a
/// comparator.
impl<'a, F: Fn(Index, Index) -> Ordering> Ord for (dyn CmpKey<F> + 'a) {
    fn cmp(&self, other: &Self) -> Ordering {
        let key = self.key();
        let other_key = other.key();
        // At least one of the cmp fns must be set. `other_key.cmp_fn` is an
        // already-computed `Copy` value, so the eager `or` is enough; no
        // closure is needed.
        let cmp = key
            .cmp_fn
            .or(other_key.cmp_fn)
            .expect("at least one key must be set");
        cmp(key.index, other_key.index)
    }
}

/// Partial order for key trait objects; always `Some`, because the order is
/// total and comes straight from the `Ord` impl.
impl<'a, F: Fn(Index, Index) -> Ordering> PartialOrd for (dyn CmpKey<F> + 'a) {
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
Loading