From b17b4fd954acbaeb55c4f1db57ac22a88b075a93 Mon Sep 17 00:00:00 2001 From: ecoricemon Date: Mon, 23 Mar 2026 02:45:17 +0900 Subject: [PATCH 1/2] feat: Implement SLG Added SLG path, which is selected for recursive clauses. --- crates/logic-eval/README.md | 32 +- crates/logic-eval/src/lib.rs | 9 +- crates/logic-eval/src/parse/common.rs | 68 +- crates/logic-eval/src/parse/repr.rs | 158 ++++- crates/logic-eval/src/parse/text.rs | 8 +- crates/logic-eval/src/prove/canonical.rs | 60 ++ crates/logic-eval/src/prove/db.rs | 186 +++++- crates/logic-eval/src/prove/mod.rs | 2 + crates/logic-eval/src/prove/prover.rs | 787 +++++++++++++++-------- crates/logic-eval/src/prove/repr.rs | 482 ++++++++++---- crates/logic-eval/src/prove/table.rs | 215 +++++++ 11 files changed, 1470 insertions(+), 537 deletions(-) create mode 100644 crates/logic-eval/src/prove/canonical.rs create mode 100644 crates/logic-eval/src/prove/table.rs diff --git a/crates/logic-eval/README.md b/crates/logic-eval/README.md index 3a036b1..e6b0a2c 100644 --- a/crates/logic-eval/README.md +++ b/crates/logic-eval/README.md @@ -11,30 +11,45 @@ [codecov-badge]: https://codecov.io/gh/ecoricemon/logic-eval/graph/badge.svg?flag=logic-eval [codecov-url]: https://app.codecov.io/gh/ecoricemon/logic-eval?flags%5B0%5D=logic-eval -A prolog-like logic evaluator. +```text ++------------------+ +| Feel(fresh) :- | +| Sleep(well), | +| Sun(shine), | +| Air(cool). | ++------------------+ +``` + +`logic-eval` is a Prolog-like logic evaluation library for Rust. + +## Features + +- `SLG resolution`: handles recursive queries with tabling. +- `Custom type support`: use `&str`, interned strings, or your own `Atom` type. +- `Parsing`: parse facts, rules, and queries from a Prolog-like text syntax with `parse_str`. +- `Basic logical operators`: supports NOT, AND, and OR in rule bodies. + +## Examples -## Example +### Parse text and query a database ```rust use logic_eval::{Database, StrInterner, parse_str}; -// Creates a DB. 
let mut db = Database::new(); let interner = StrInterner::new(); -// Initializes the DB with a little bit of logic. let dataset = " child(a, b). child(b, c). + child(c, d). descend($X, $Y) :- child($X, $Y). descend($X, $Z) :- child($X, $Y), descend($Y, $Z). "; db.insert_dataset(parse_str(dataset, &interner).unwrap()); db.commit(); -// Queries the DB. -let query = "descend($X, $Y)."; -let mut cx = db.query(parse_str(query, &interner).unwrap()); +let mut cx = db.query(parse_str("descend($X, $Y).", &interner).unwrap()); let mut answer = Vec::new(); while let Some(eval) = cx.prove_next() { @@ -45,6 +60,9 @@ while let Some(eval) = cx.prove_next() { assert_eq!(answer, [ "$X = a, $Y = b", "$X = b, $Y = c", + "$X = c, $Y = d", "$X = a, $Y = c", + "$X = b, $Y = d", + "$X = a, $Y = d", ]); ``` diff --git a/crates/logic-eval/src/lib.rs b/crates/logic-eval/src/lib.rs index 42bdeb3..d71b99c 100644 --- a/crates/logic-eval/src/lib.rs +++ b/crates/logic-eval/src/lib.rs @@ -6,7 +6,7 @@ mod prove; // === Re-exports === pub use parse::{ - common::{Intern, InternedStr, StrCanonicalizer, StrInterner}, + common::{Intern, InternedStr, StrInterner}, inner::VAR_PREFIX, inner::{parse_str, Parse}, repr::{Clause, ClauseDataset, Expr, Predicate, Term}, @@ -37,7 +37,10 @@ mod intern_alias { pub(crate) type ClauseDatasetIn<'int, Int> = parse::repr::ClauseDataset>; } -// === Hash map and set used within this crate === +pub(crate) type Map = fxhash::FxHashMap; +pub(crate) type IndexMap = indexmap::IndexMap; +pub(crate) type IndexSet = indexmap::IndexSet; +pub(crate) type PassThroughIndexMap = indexmap::IndexMap; use std::{ error::Error as StdError, @@ -45,8 +48,6 @@ use std::{ result::Result as StdResult, }; -pub(crate) type Map = fxhash::FxHashMap; - #[derive(Default, Clone, Copy)] struct PassThroughHasher { hash: u64, diff --git a/crates/logic-eval/src/parse/common.rs b/crates/logic-eval/src/parse/common.rs index b45b867..370aa06 100644 --- a/crates/logic-eval/src/parse/common.rs +++ 
b/crates/logic-eval/src/parse/common.rs @@ -1,6 +1,5 @@ -use crate::{Atom, Clause, Expr, Term}; +use crate::Atom; use core::fmt::{self, Display}; -use smallvec::SmallVec; pub trait Intern { type Interned<'a>: Atom @@ -56,68 +55,3 @@ impl Intern for any_intern::Interner { pub type StrInterner = any_intern::DroplessInterner; pub type InternedStr<'int> = any_intern::Interned<'int, str>; - -pub struct StrCanonicalizer<'int> { - interner: &'int StrInterner, -} - -impl<'int> StrCanonicalizer<'int> { - pub fn new(interner: &'int StrInterner) -> Self { - Self { interner } - } - - pub fn canonicalize(&self, clause: Clause>) -> Clause> { - let mut vars = SmallVec::new(); - find_var_in_clause(&clause, &mut vars); - - let mut clause = clause; - clause.replace_term(&mut |term| { - if !term.args.is_empty() { - return None; - } - vars.iter().enumerate().find_map(|(i, var)| { - if &term.functor == var { - Some(Term { - functor: self.interner.intern_formatted_str(&i, i % 10 + 1).unwrap(), - args: Vec::new(), - }) - } else { - None - } - }) - }); - - return clause; - - // === Internal helper functions === - - fn find_var_in_clause(clause: &Clause, vars: &mut SmallVec<[T; 4]>) { - find_var_in_term(&clause.head, vars); - if let Some(body) = &clause.body { - find_var_in_expr(body, vars); - } - } - - fn find_var_in_expr(expr: &Expr, vars: &mut SmallVec<[T; 4]>) { - match expr { - Expr::Term(term) => find_var_in_term(term, vars), - Expr::Not(expr) => find_var_in_expr(expr, vars), - Expr::And(expr) | Expr::Or(expr) => { - for inner_expr in expr.iter() { - find_var_in_expr(inner_expr, vars); - } - } - } - } - - fn find_var_in_term(term: &Term, vars: &mut SmallVec<[T; 4]>) { - if term.functor.is_variable() { - vars.push(term.functor.clone()); - } else { - for arg in &term.args { - find_var_in_term(arg, vars); - } - } - } - } -} diff --git a/crates/logic-eval/src/parse/repr.rs b/crates/logic-eval/src/parse/repr.rs index eb7ac85..5151320 100644 --- a/crates/logic-eval/src/parse/repr.rs +++ 
b/crates/logic-eval/src/parse/repr.rs @@ -1,5 +1,7 @@ -use super::text::Name; -use crate::Atom; +use crate::{ + prove::{canonical as canon, prover::Integer}, + Atom, +}; use std::{ fmt::{self, Debug, Display, Write}, ops, @@ -62,6 +64,44 @@ impl Clause { } } +impl Clause { + /// Returns true if the clause needs SLG resolution (tabling). + /// + /// If a clause has left or mid recursion, it must be handled by tabling. + /// + /// # Examples + /// foo(X, Y) :- foo(A, B) ... // left recursion + /// foo(X, Y) :- ... foo(A, B) ... // mid recursion + pub fn needs_tabling(&self) -> bool { + return if let Some(body) = &self.body { + let mut head = self.head.clone(); + let mut body = body.clone(); + canon::canonicalize_term(&mut head); + canon::canonicalize_expr_on_term(&mut body); + helper(&body.distribute_not(), &head) + } else { + false + }; + + // === Internal helper functions === + + fn helper(expr: &Expr, head: &Term) -> bool { + match expr { + Expr::Term(term) => term == head, + Expr::Not(arg) => helper(arg, head), + Expr::And(args) => { + if let Some((last, first)) = args.split_last() { + first.iter().any(|arg| helper(arg, head)) || helper(last, head) + } else { + false + } + } + Expr::Or(args) => args.iter().any(|arg| helper(arg, head)), + } + } + } +} + impl Display for Clause { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.head.fmt(f)?; @@ -127,7 +167,7 @@ impl Term { } } -impl Term> { +impl Term { pub fn is_variable(&self) -> bool { let is_variable = self.functor.is_variable(); @@ -146,6 +186,23 @@ impl Term> { self.args.iter().any(|arg| arg.contains_variable()) } + + pub fn replace_variables(&mut self, mut f: F) { + fn helper(term: &mut Term, f: &mut F) + where + T: Atom, + F: FnMut(&mut T), + { + if term.is_variable() { + f(&mut term.functor); + } else { + for arg in &mut term.args { + helper(arg, f); + } + } + } + helper(self, &mut f) + } } impl Display for Term { @@ -190,20 +247,20 @@ impl Expr { Self::Not(Box::new(expr)) } - pub fn 
expr_and>>(elems: I) -> Self { - Self::And(elems.into_iter().collect()) + pub fn expr_and>>(args: I) -> Self { + Self::And(args.into_iter().collect()) } - pub fn expr_or>>(elems: I) -> Self { - Self::Or(elems.into_iter().collect()) + pub fn expr_or>>(args: I) -> Self { + Self::Or(args.into_iter().collect()) } pub fn map U>(self, f: &mut F) -> Expr { match self { - Self::Term(v) => Expr::Term(v.map(f)), - Self::Not(v) => Expr::Not(Box::new(v.map(f))), - Self::And(v) => Expr::And(v.into_iter().map(|expr| expr.map(f)).collect()), - Self::Or(v) => Expr::Or(v.into_iter().map(|expr| expr.map(f)).collect()), + Self::Term(term) => Expr::Term(term.map(f)), + Self::Not(arg) => Expr::Not(Box::new(arg.map(f))), + Self::And(args) => Expr::And(args.into_iter().map(|arg| arg.map(f)).collect()), + Self::Or(args) => Expr::Or(args.into_iter().map(|arg| arg.map(f)).collect()), } } @@ -225,18 +282,51 @@ impl Expr { } } +impl Expr { + pub fn contains_term(&self, term: &Term) -> bool { + match self { + Self::Term(t) => t == term, + Self::Not(arg) => arg.contains_term(term), + Self::And(args) | Self::Or(args) => args.iter().any(|arg| arg.contains_term(term)), + } + } + + /// e.g. 
¬(A ∧ (B ∨ C)) -> ¬A ∨ (¬B ∧ ¬C) + pub fn distribute_not(self) -> Self { + match self { + Self::Term(term) => Self::Term(term), + Self::Not(expr) => match *expr { + Self::Term(term) => Self::Not(Box::new(Self::Term(term))), + Self::Not(inner) => inner.distribute_not(), + Self::And(args) => Self::Or( + args.into_iter() + .map(|arg| Self::Not(Box::new(arg)).distribute_not()) + .collect(), + ), + Self::Or(args) => Self::And( + args.into_iter() + .map(|arg| Self::Not(Box::new(arg)).distribute_not()) + .collect(), + ), + }, + Self::And(args) => Self::And(args.into_iter().map(Self::distribute_not).collect()), + Self::Or(args) => Self::Or(args.into_iter().map(Self::distribute_not).collect()), + } + } +} + impl Display for Expr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Term(term) => term.fmt(f)?, - Self::Not(inner) => { + Self::Not(arg) => { f.write_str("\\+ ")?; - if matches!(**inner, Self::And(_) | Self::Or(_)) { + if matches!(**arg, Self::And(_) | Self::Or(_)) { f.write_char('(')?; - inner.fmt(f)?; + arg.fmt(f)?; f.write_char(')')?; } else { - inner.fmt(f)?; + arg.fmt(f)?; } } Self::And(args) => { @@ -266,8 +356,44 @@ impl Display for Expr { } } -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Predicate { pub functor: T, pub arity: u32, } + +#[cfg(test)] +mod tests { + use super::{Expr, Term}; + + #[test] + fn distribute_not_applies_de_morgan() { + let expr = Expr::expr_not(Expr::expr_and([ + Expr::term_atom("a"), + Expr::expr_or([Expr::term_atom("b"), Expr::term_atom("c")]), + ])); + + let expected = Expr::expr_or([ + Expr::expr_not(Expr::term_atom("a")), + Expr::expr_and([ + Expr::expr_not(Expr::term_atom("b")), + Expr::expr_not(Expr::term_atom("c")), + ]), + ]); + + assert_eq!(expr.distribute_not(), expected); + } + + #[test] + fn distribute_not_removes_double_negation() { + let expr = Expr::expr_not(Expr::expr_not(Expr::term(Term::compound( + "f", + [Term::atom("x")], + )))); + + 
assert_eq!( + expr.distribute_not(), + Expr::term(Term::compound("f", [Term::atom("x")])) + ); + } +} diff --git a/crates/logic-eval/src/parse/text.rs b/crates/logic-eval/src/parse/text.rs index a81c78f..ef78622 100644 --- a/crates/logic-eval/src/parse/text.rs +++ b/crates/logic-eval/src/parse/text.rs @@ -3,7 +3,7 @@ use super::{ CloseParenToken, CommaToken, DotToken, HornToken, Ident, NegationToken, OpenParenToken, Parse, ParseBuffer, }; -use crate::{Atom, ClauseDatasetIn, ClauseIn, Error, ExprIn, Intern, NameIn, Result, TermIn}; +use crate::{ClauseDatasetIn, ClauseIn, Error, ExprIn, Intern, NameIn, Result, TermIn}; use std::{ borrow::Borrow, fmt::{self, Debug, Display}, @@ -228,12 +228,6 @@ impl Name<()> { } } -impl Name { - pub(crate) fn is_variable(&self) -> bool { - self.0.is_variable() - } -} - impl<'int, Int: Intern> Parse<'int, Int> for NameIn<'int, Int> { fn parse(buf: &mut ParseBuffer<'_>, interner: &'int Int) -> Result { let ident = buf.parse::(interner)?; diff --git a/crates/logic-eval/src/prove/canonical.rs b/crates/logic-eval/src/prove/canonical.rs new file mode 100644 index 0000000..c96d0f8 --- /dev/null +++ b/crates/logic-eval/src/prove/canonical.rs @@ -0,0 +1,60 @@ +use crate::{ + prove::{ + prover::Integer, + repr::{TermId, TermStorage, TermViewMut}, + }, + Atom, Expr, Map, Term, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct CanonicalTermId(TermId); + +pub(crate) fn canonicalize_term_id(stor: &mut TermStorage, id: TermId) -> CanonicalTermId { + let mut view = stor.get_term_mut(id); + canonicalize_term_view(&mut view); + CanonicalTermId(view.id()) +} + +/// e.g. f(X, Y, X) -> f(0, 1, 0) +pub(crate) fn canonicalize_term(term: &mut Term) { + let mut c = canonicalizer(); + term.replace_variables(|functor| *functor = c(*functor)); +} + +/// Applies [`canonicalize_term`] on each term without crossing term boundaries. +/// +/// e.g. 
f(X), g(Y, X) -> f(0), g(0, 1) (not f(0), g(1, 0)) +pub(crate) fn canonicalize_expr_on_term(expr: &mut Expr) { + match expr { + Expr::Term(term) => canonicalize_term(term), + Expr::Not(arg) => canonicalize_expr_on_term(arg), + Expr::And(args) | Expr::Or(args) => { + for arg in args { + canonicalize_expr_on_term(arg); + } + } + } +} + +pub(crate) fn canonicalize_term_view(view: &mut TermViewMut<'_, Integer>) { + let mut c = canonicalizer(); + view.replace_with(|functor| { + if functor.is_variable() { + Some(c(*functor)) + } else { + None + } + }); +} + +fn canonicalizer() -> impl FnMut(Integer) -> Integer { + let mut map = Map::default(); + move |functor: Integer| { + if functor.is_variable() { + let next_int = map.len() as u32; + *map.entry(functor).or_insert(Integer::variable(next_int)) + } else { + functor + } + } +} diff --git a/crates/logic-eval/src/prove/db.rs b/crates/logic-eval/src/prove/db.rs index 67f22f6..b92762c 100644 --- a/crates/logic-eval/src/prove/db.rs +++ b/crates/logic-eval/src/prove/db.rs @@ -11,20 +11,22 @@ use crate::{ VAR_PREFIX, }, prove::repr::{ExprKind, ExprView, TermView, TermViewIter}, - Atom, Map, + Atom, IndexMap, IndexSet, Map, }; use core::{ fmt::{self, Debug, Display, Write}, iter::FusedIterator, }; -use indexmap::{IndexMap, IndexSet}; pub struct Database { /// Clause id dataset. clauses: IndexMap, Vec>, - /// We do not allow duplicated clauses in the dataset. - clause_set: IndexSet>, + /// Clauses that should be handled by tabling. + table_clauses: IndexSet>, + + /// We do not allow duplicate clauses in the dataset. + dup_checker: DuplicateClauseChecker, /// Term and expression storage. 
stor: TermStorage, @@ -48,7 +50,8 @@ impl Database { pub fn new() -> Self { Self { clauses: IndexMap::default(), - clause_set: IndexSet::default(), + table_clauses: IndexSet::default(), + dup_checker: DuplicateClauseChecker::default(), stor: TermStorage::new(), prover: Prover::new(), nimap: NameIntMap::new(), @@ -59,7 +62,7 @@ impl Database { pub fn terms(&self) -> NamedTermViewIter<'_, T> { NamedTermViewIter { term_iter: self.stor.terms.terms(), - int2name: &self.nimap.int2name, + nimap: &self.nimap, } } @@ -67,7 +70,7 @@ impl Database { ClauseIter { clauses: &self.clauses, stor: &self.stor, - int2name: &self.nimap.int2name, + nimap: &self.nimap, i: 0, j: 0, } @@ -79,19 +82,25 @@ impl Database { } } + /// Inserts the given clause to the DB. pub fn insert_clause(&mut self, clause: Clause) { // Saves current state. We will revert DB when the change is not committed. if self.revert_point.is_none() { self.revert_point = Some(self.state()); } + let clause = clause.map(&mut |t| self.nimap.name_to_int(t)); + + // Records whether the clause needs tabling. + if clause.needs_tabling() { + self.table_clauses.insert(clause.head.predicate()); + } + // If the DB already contains the given clause, then returns. 
- if !self.clause_set.insert(clause.clone()) { + if !self.dup_checker.insert(clause.clone()) { return; } - let clause = clause.map(&mut |t| self.nimap.name_to_int(t)); - let key = clause.head.predicate(); let value = ClauseId { head: self.stor.insert_term(clause.head), @@ -114,8 +123,13 @@ impl Database { self.revert(revert_point); } - self.prover - .prove(expr, &self.clauses, &mut self.stor, &mut self.nimap) + self.prover.prove( + expr, + &self.clauses, + &self.table_clauses, + &mut self.stor, + &mut self.nimap, + ) } pub fn commit(&mut self) { @@ -135,7 +149,7 @@ impl Database { let mut conv_map = ConversionMap { int_to_str: Map::default(), sanitized_to_suffix: Map::default(), - int2name: &self.nimap.int2name, + nimap: &self.nimap, sanitizer: sanitize, }; @@ -163,7 +177,7 @@ impl Database { int_to_str: Map, // e.g. 0 -> No suffix, 1 -> _1, 2 -> _2, ... sanitized_to_suffix: Map<&'a str, u32>, - int2name: &'a IndexMap, + nimap: &'a NameIntMap, sanitizer: F, } @@ -174,7 +188,7 @@ impl Database { { fn int_to_str(&mut self, int: Integer) -> &str { self.int_to_str.entry(int).or_insert_with(|| { - let name = self.int2name.get(&int).unwrap(); + let name = self.nimap.get_name(&int).unwrap(); let name: &str = name.as_ref(); let mut is_var = false; @@ -311,7 +325,7 @@ impl Database { for (i, len) in clauses_len.into_iter().enumerate() { self.clauses[i].truncate(len); } - self.clause_set.truncate(clause_set_len); + self.dup_checker.truncate(clause_set_len); self.stor.truncate(stor_len); self.nimap.revert(nimap_state); // `self.prover: Prover` does not store any persistent data. 
@@ -320,7 +334,7 @@ impl Database { fn state(&self) -> DatabaseState { DatabaseState { clauses_len: self.clauses.values().map(|v| v.len()).collect(), - clause_set_len: self.clause_set.len(), + clause_set_len: self.dup_checker.len(), stor_len: self.stor.len(), nimap_state: self.nimap.state(), } @@ -337,7 +351,7 @@ impl Debug for Database { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Database") .field("clauses", &self.clauses) - .field("clause_set", &self.clause_set) + .field("dup_checker", &self.dup_checker) .field("stor", &self.stor) .field("nimap", &self.nimap) .field("revert_point", &self.revert_point) @@ -353,6 +367,45 @@ struct DatabaseState { nimap_state: NameIntMapState, } +#[derive(Debug, Default)] +struct DuplicateClauseChecker { + seen: IndexSet>, + + /// Temporary buffer for granting [`Integer`] to variables. + vars: Vec, +} + +impl DuplicateClauseChecker { + /// Returns true if the given clause is new, has not been seen before. + fn insert(&mut self, clause: Clause) -> bool { + let canonical_clause = clause.map(&mut |t| { + if !t.is_variable() { + t + } else { + if let Some(found) = self.vars.iter().find(|&&var| var == t) { + *found + } else { + let next_int = self.vars.len() as u32; + let int = Integer::variable(next_int); + self.vars.push(int); + int + } + } + }); + let is_new = self.seen.insert(canonical_clause); + self.vars.clear(); + is_new + } + + fn len(&self) -> usize { + self.seen.len() + } + + fn truncate(&mut self, len: usize) { + self.seen.truncate(len); + } +} + /// Turns variables into `_$0`, `_$1`, and so on using the given canonical_var function. /// /// Returns `None` if `canonical_var` is `None` (i.e. deduplication disabled). 
@@ -397,9 +450,9 @@ fn _convert_var_into_num( fn find_var_in_expr(expr: &Expr) -> Option { match expr { - Expr::Term(t) => find_var_in_term(t), - Expr::Not(e) => find_var_in_expr(e), - Expr::And(v) | Expr::Or(v) => v.iter().find_map(find_var_in_expr), + Expr::Term(term) => find_var_in_term(term), + Expr::Not(arg) => find_var_in_expr(arg), + Expr::And(args) | Expr::Or(args) => args.iter().find_map(find_var_in_expr), } } @@ -416,7 +469,7 @@ fn _convert_var_into_num( pub struct ClauseIter<'a, T> { clauses: &'a IndexMap, Vec>, stor: &'a TermStorage, - int2name: &'a IndexMap, + nimap: &'a NameIntMap, i: usize, j: usize, } @@ -440,7 +493,7 @@ impl<'a, T> Iterator for ClauseIter<'a, T> { Some(ClauseRef { id, stor: self.stor, - int2name: self.int2name, + nimap: self.nimap, }) } } @@ -450,19 +503,19 @@ impl FusedIterator for ClauseIter<'_, T> {} pub struct ClauseRef<'a, T> { id: ClauseId, stor: &'a TermStorage, - int2name: &'a IndexMap, + nimap: &'a NameIntMap, } impl<'a, T: Atom> ClauseRef<'a, T> { pub fn head(&self) -> NamedTermView<'a, T> { let head = self.stor.get_term(self.id.head); - NamedTermView::new(head, self.int2name) + NamedTermView::new(head, self.nimap) } pub fn body(&self) -> Option> { self.id.body.map(|id| { let body = self.stor.get_expr(id); - NamedExprView::new(body, self.int2name) + NamedExprView::new(body, self.nimap) }) } } @@ -485,11 +538,11 @@ impl Debug for ClauseRef<'_, T> { let mut d = f.debug_struct("Clause"); let head = self.stor.get_term(self.id.head); - d.field("head", &NamedTermView::new(head, self.int2name)); + d.field("head", &NamedTermView::new(head, self.nimap)); if let Some(body) = self.id.body { let body = self.stor.get_expr(body); - d.field("body", &NamedExprView::new(body, self.int2name)); + d.field("body", &NamedExprView::new(body, self.nimap)); } d.finish() @@ -498,7 +551,7 @@ impl Debug for ClauseRef<'_, T> { pub struct NamedTermViewIter<'a, T> { term_iter: TermViewIter<'a, Integer>, - int2name: &'a IndexMap, + nimap: &'a 
NameIntMap, } impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> { @@ -507,7 +560,7 @@ impl<'a, T: Atom> Iterator for NamedTermViewIter<'a, T> { fn next(&mut self) -> Option { self.term_iter .next() - .map(|view| NamedTermView::new(view, self.int2name)) + .map(|view| NamedTermView::new(view, self.nimap)) } } @@ -719,6 +772,79 @@ mod str_atom_tests { assert_eq!(answer, expected); } + // SLG resolution (tabling) is required to pass this test. + #[test] + fn test_mid_recursion() { + let mut db = Database::new(); + let interner = Interner::new(); + + insert_dataset( + &mut db, + &interner, + r" + edge(a, b). + edge(b, c). + edge(c, a). + path($X, $Y) :- edge($X, $Z), path($Z, $W), edge($W, $Y). + path($X, $Y) :- edge($X, $Y). + ", + ); + + let query: Expr<'_> = parse::parse_str("path($X, $Y).", &interner).unwrap(); + let mut answer = collect_answer(db.query(query)); + + let mut expected = [ + ["$X = a", "$Y = a"], + ["$X = a", "$Y = b"], + ["$X = a", "$Y = c"], + ["$X = b", "$Y = a"], + ["$X = b", "$Y = b"], + ["$X = b", "$Y = c"], + ["$X = c", "$Y = a"], + ["$X = c", "$Y = b"], + ["$X = c", "$Y = c"], + ]; + + answer.sort_unstable(); + expected.sort_unstable(); + assert_eq!(answer, expected); + } + + // SLG resolution (tabling) is required to pass this test. + #[test] + fn test_left_recursion() { + let mut db = Database::new(); + let interner = Interner::new(); + + insert_dataset( + &mut db, + &interner, + r" + parent(a, b). + parent(b, c). + parent(c, d). + ancestor($X, $Y) :- ancestor($X, $Z), parent($Z, $Y). + ancestor($X, $Y) :- parent($X, $Y). 
+ ", + ); + + let query: Expr<'_> = parse::parse_str("ancestor($X, $Y).", &interner).unwrap(); + let mut answer = collect_answer(db.query(query)); + + let mut expected = [ + ["$X = a", "$Y = b"], + ["$X = a", "$Y = c"], + ["$X = a", "$Y = d"], + ["$X = b", "$Y = c"], + ["$X = b", "$Y = d"], + ["$X = c", "$Y = d"], + ]; + + answer.sort_unstable(); + expected.sort_unstable(); + assert_eq!(answer, expected); + } + #[test] fn test_discarding_uncomitted_change() { let mut db = Database::new(); diff --git a/crates/logic-eval/src/prove/mod.rs b/crates/logic-eval/src/prove/mod.rs index fae0be2..683d77c 100644 --- a/crates/logic-eval/src/prove/mod.rs +++ b/crates/logic-eval/src/prove/mod.rs @@ -1,4 +1,6 @@ +pub(crate) mod canonical; pub(crate) mod common; pub(crate) mod db; pub(crate) mod prover; pub(crate) mod repr; +pub(crate) mod table; diff --git a/crates/logic-eval/src/prove/prover.rs b/crates/logic-eval/src/prove/prover.rs index af17d06..f98731c 100644 --- a/crates/logic-eval/src/prove/prover.rs +++ b/crates/logic-eval/src/prove/prover.rs @@ -1,10 +1,15 @@ -use super::repr::{ - ApplyResult, ClauseId, ExprId, ExprKind, ExprView, TermDeepView, TermElem, TermId, TermStorage, - TermStorageLen, TermView, TermViewMut, UniqueTermArray, +use super::{ + canonical, + repr::{ + ApplyResult, ClauseId, ExprId, ExprKind, ExprView, TermDeepView, TermElem, TermId, + TermStorage, TermStorageLen, TermView, TermViewMut, UniqueTermArray, + }, + table::Table, }; use crate::{ parse::repr::{Expr, Predicate, Term}, - Atom, Map, VAR_PREFIX, + prove::table::{TableEntry, TableIndex}, + Atom, IndexMap, IndexSet, Map, VAR_PREFIX, }; use core::{ fmt::{self, Debug, Display, Write}, @@ -12,11 +17,9 @@ use core::{ iter, ops::{self, Range}, }; -use indexmap::IndexMap; +use smallvec::SmallVec; use std::collections::VecDeque; -pub(crate) type ClauseMap = IndexMap, Vec>; - #[derive(Debug)] pub(crate) struct Prover { uni_op: UnificationOperator, @@ -24,11 +27,8 @@ pub(crate) struct Prover { /// Nodes 
created during proof search. nodes: Vec, - /// Variable assignments. - /// - /// For example, `assignment[X] = a` means that `X(term id)` is assigned to `a(term id)`. If a - /// value is identical to its index, it means the term is not assigned to anything. - assignments: Vec, + /// Variable assignments (e.g. X = a, Y = z) + term_assigns: TermAssignments, /// A given query. query: ExprId, @@ -38,8 +38,11 @@ pub(crate) struct Prover { /// This could be used to find what terms these variables are assigned to. query_vars: Vec, + /// Previously returned query answers. + query_answers: Vec>, + /// Task queue containing node index. - queue: VecDeque, + queue: NodeQueue, /// A buffer containing mapping between variables and temporary variables. /// @@ -49,6 +52,9 @@ pub(crate) struct Prover { /// A monotonically increasing integer that is used for generating temporary variables. temp_var_int: u32, + + /// SLG resolution. + table: Table, } impl Prover { @@ -56,27 +62,32 @@ impl Prover { Self { uni_op: UnificationOperator::new(), nodes: Vec::new(), - assignments: Vec::new(), + term_assigns: TermAssignments::default(), query: ExprId(0), query_vars: Vec::new(), - queue: VecDeque::new(), + query_answers: Vec::new(), + queue: NodeQueue::default(), temp_var_buf: Map::default(), temp_var_int: 0, + table: Table::default(), } } fn clear(&mut self) { self.uni_op.clear(); self.nodes.clear(); - self.assignments.clear(); + self.term_assigns.clear(); self.query_vars.clear(); + self.query_answers.clear(); self.queue.clear(); + self.table.clear(); } pub(crate) fn prove<'a, T: Atom>( &'a mut self, query: Expr, - clause_map: &'a ClauseMap, + clauses: &'a IndexMap, Vec>, + table_clauses: &'a IndexSet>, stor: &'a mut TermStorage, nimap: &'a mut NameIntMap, ) -> ProveCx<'a, T> { @@ -90,48 +101,109 @@ impl Prover { stor.get_expr(self.query) .with_term(&mut |term: TermView<'_, Integer>| { - term.with_variable(&mut |term| self.query_vars.push(term.id)); + term.with_variable(|term| 
self.query_vars.push(term.id)); }); - self.nodes.push(Node { - kind: NodeKind::Expr(self.query), - uni_delta: 0..0, - parent: self.nodes.len(), - }); - self.queue.push_back(0); + let node_kind = NodeKind::Expr(self.query); + let node_parent = self.nodes.len(); + self.nodes.push(Node::new(node_kind, node_parent)); + self.queue.push(0); - ProveCx::new(self, clause_map, stor, nimap, old_stor_len, old_nimap_state) + ProveCx { + prover: self, + clauses, + table_clauses, + stor, + nimap, + old_stor_len, + old_nimap_state, + } } /// Evaluates the given node with all possible clauses in the clause dataset, then returns /// whether a proof search path is complete or not. /// /// If it reached an end of paths, it returns proof search result within `Some`. The proof - /// search result is either true or false, which means the expression in the given node is true - /// or not. + /// search result is either true or false, which means the expression in the given node is + /// evaluated as true or false. fn evaluate_node( &mut self, node_index: usize, - clause_map: &ClauseMap, + clauses: &IndexMap, Vec>, + table_clauses: &IndexSet>, stor: &mut TermStorage, ) -> Option { - let node = self.nodes[node_index].clone(); - let node_expr = match node.kind { + let node_expr = match self.nodes[node_index].kind { NodeKind::Expr(expr_id) => expr_id, NodeKind::Leaf(eval) => { - self.find_assignment(node_index); + self.find_assignments(node_index); + // On a successful proof, records the answer in the nearest ancestor-owned SLG + // table entry, then notifies all waiting consumers. 
+ if eval { + self.update_answer_and_notify(node_index, stor); + } return Some(eval); } }; - let predicate = stor.get_expr(node_expr).leftmost_term().predicate(); - let similar_clauses = if let Some(v) = clause_map.get(&predicate) { - v.as_slice() - } else { - &[] - }; + let node_leftmost = stor.get_expr(node_expr).leftmost_term().id; + let node_leftmost_pred = stor.get_term(node_leftmost).predicate(); + let mut similar_clauses = &[][..]; + let mut clause_buf: SmallVec<[ClauseId; 1]> = SmallVec::new(); + + // === SLG path === + // * Table entry - Created from non-canonical leftmost term of the node. In tabling, + // we use canonical variables for table keys only. + + if table_clauses.contains(&node_leftmost_pred) { + let key = canonical::canonicalize_term_id(stor, node_leftmost); + if let Some((_, entry)) = self.table.get_mut(&key) { + entry.register_consumer(node_index); + + // No answers yet? the node may be woken up by notification. + let answer_offset = self.nodes[node_index].table_answer_offset; + let answers = entry.answers(answer_offset); + if answers.is_empty() { + return None; + } + let next_offset = answer_offset + 1; + self.nodes[node_index].table_answer_offset = next_offset; + + // Synthesizes an answer clause, then lets it be unified with the current node. + let mut term = stor.get_term_mut(node_leftmost); + let vars = term.as_view().collect_variables(); + for (var, answer) in vars.into_iter().zip(answers) { + term.replace(var, *answer); + } + clause_buf.push(ClauseId { + head: term.id(), + body: None, + }); + similar_clauses = &clause_buf[..]; + + // More answers? We'll handle them next time. + if !entry.answers(next_offset).is_empty() { + self.queue.push(node_index); + } + } else { + // First encounter: Just creates a table entry then proceeds with SLD. 
+ if let Some(entry) = TableEntry::from_term_view(&stor.get_term(node_leftmost)) { + let index = self.table.register(key, entry); + self.nodes[node_index].table_owner = Some(index); + } + } + } + + // === BFS based SLD path === + + if similar_clauses.is_empty() { + if let Some(v) = clauses.get(&node_leftmost_pred) { + similar_clauses = v.as_slice() + } + } let old_len = self.nodes.len(); + for clause in similar_clauses { let head = stor.get_term(clause.head); @@ -147,14 +219,14 @@ impl Prover { ); if let Some(new_node) = self.unify_node_with_clause(node_index, clause, stor) { self.nodes.push(new_node); - self.queue.push_back(self.nodes.len() - 1); + self.queue.push(self.nodes.len() - 1); } } // We may need to apply true or false to the leftmost term of the node expression due to // unification failure or exhaustive search. // - Unification failure means the leftmost term should be false. - // - But we need to consider exhaustive search possibility at the same time. + // - But we need to consider exhaustive search at the same time. let expr = stor.get_expr(node_expr); let eval = self.nodes.len() > old_len; @@ -172,16 +244,13 @@ impl Prover { if let Some(to) = need_apply { let mut expr = stor.get_expr_mut(node_expr); - let kind = match expr.apply_to_leftmost_term(to) { + let node_kind = match expr.apply_to_leftmost_term(to) { ApplyResult::Expr => NodeKind::Expr(expr.id()), ApplyResult::Complete(eval) => NodeKind::Leaf(eval), }; - self.nodes.push(Node { - kind, - uni_delta: 0..0, - parent: node_index, - }); - self.queue.push_back(self.nodes.len() - 1); + let node_parent = node_index; + self.nodes.push(Node::new(node_kind, node_parent)); + self.queue.push(self.nodes.len() - 1); } return None; @@ -253,6 +322,45 @@ impl Prover { } } + /// Finds the nearest ancestor node that owns SLG table entry, then updates the entry and + /// notifies all waiting consumers. 
+ fn update_answer_and_notify(&mut self, node_index: usize, stor: &TermStorage) { + let tabled_ancestor = { + let mut cur = node_index; + loop { + if self.nodes[cur].table_owner.is_some() { + break Some(cur); + } + let parent = self.nodes[cur].parent; + if parent == cur { + break None; + } + cur = parent; + } + }; + + if let Some(ancestor) = tabled_ancestor { + let table_index = self.nodes[ancestor].table_owner.unwrap(); + let entry = &mut self.table[table_index]; + let all_answers_concrete = entry.variables().iter().all(|&var| { + if let Some(answer) = self.term_assigns.find(var) { + !stor.get_term(answer).contains_variable() + } else { + false + } + }); + + if all_answers_concrete && !entry.has_answer(&self.term_assigns) { + entry.update_answer(&self.term_assigns); + for i in entry.consumer_nodes() { + if i != node_index { + self.queue.push(i); + } + } + } + } + } + /// Replaces variables in a clause with other temporary variables. // // Why we replace variables with temporary variables in clauses before unifying? 
@@ -328,52 +436,39 @@ impl Prover { { return None; } - let (node_expr, clause, uni_delta) = self.uni_op.consume_ops(stor, node_expr, clause); + let (node_expr, clause, uni_history) = self.uni_op.consume_ops(stor, node_expr, clause); if let Some(body) = clause.body { let mut lhs = stor.get_expr_mut(node_expr); lhs.replace_leftmost_term(body); - return Some(Node { - kind: NodeKind::Expr(lhs.id()), - uni_delta, - parent: node_index, - }); + let node_kind = NodeKind::Expr(lhs.id()); + let node_parent = node_index; + let node = Node::new(node_kind, node_parent).with_unification_history(uni_history); + return Some(node); } let mut lhs = stor.get_expr_mut(node_expr); - let kind = match lhs.apply_to_leftmost_term(true) { + let node_kind = match lhs.apply_to_leftmost_term(true) { ApplyResult::Expr => NodeKind::Expr(lhs.id()), ApplyResult::Complete(eval) => NodeKind::Leaf(eval), }; - Some(Node { - kind, - uni_delta, - parent: node_index, - }) + let node_parent = node_index; + let node = Node::new(node_kind, node_parent).with_unification_history(uni_history); + Some(node) } - /// Finds all assignments from the given node to the root node. - /// - /// Then, the assignment information is stored at [`Self::assignments`]. - fn find_assignment(&mut self, node_index: usize) { - // Collects unification records. - self.assignments.clear(); + /// Finds all from/to relations while traversing from the given node to the root node then add + /// the relations to [`TermAssignments`]. 
+ fn find_assignments(&mut self, node_index: usize) { + self.term_assigns.clear(); let mut cur_index = node_index; loop { let node = &self.nodes[cur_index]; - let range = node.uni_delta.clone(); + let range = node.uni_history.clone(); for (from, to) in self.uni_op.get_record(range).iter().cloned() { - let (from, to) = (from.0, to.0); - - for i in self.assignments.len()..=from.max(to) { - self.assignments.push(i); - } - - let root_from = find(&mut self.assignments, from); - let root_to = find(&mut self.assignments, to); - self.assignments[root_from] = root_to; + self.term_assigns.add(from, to); } if node.parent == cur_index { @@ -381,31 +476,86 @@ impl Prover { } cur_index = node.parent; } + } - return; + /// Records the current proof result as a query answer if it is ground and not duplicated, + /// then returns whether a new answer was recorded. + fn record_query_answer(&mut self, stor: &mut TermStorage) -> bool { + let mut answer = Vec::with_capacity(self.query_vars.len()); + for &var in &self.query_vars { + let Some(resolved) = self.materialize_assigned_term(var, stor) else { + return false; + }; + answer.push(resolved); + } - // === Internal helper functions === + // no query vars -> empty iter -> all() returns true + if self.query_answers.iter().all(|seen| seen != &answer) { + self.query_answers.push(answer); + true + } else { + false + } + } - fn find(buf: &mut [usize], i: usize) -> usize { - if buf[i] == i { - i - } else { - let root = find(buf, buf[i]); - buf[i] = root; - root + /// Builds a fully substituted term for a query-side term from `term_assigns`. + /// + /// Examples: + /// + /// | assignments | input | output | + /// | ------------------- | :------: | :------: | + /// | `T = Vec(a)` | `T` | `Vec(a)` | + /// | `T = a` | `Vec(T)` | `Vec(a)` | + /// | `T = Vec(U), U = a` | `T` | `Vec(a)` | + /// | `T = Vec(U)` | `T` | `None` | + /// + /// This must materialize the whole term tree, not just rewrite functors in place. 
The returned + /// `TermId` always points to a ground term inserted into `stor`. + fn materialize_assigned_term( + &self, + term_id: TermId, + stor: &mut TermStorage, + ) -> Option { + let term = stor.get_term(term_id); + if term.is_variable() { + let resolved = self.term_assigns.find(term_id)?; + if resolved == term_id { + return None; } + return self.materialize_assigned_term(resolved, stor); } + + let functor = *term.functor(); + let arg_ids = term.args().map(|arg| arg.id).collect::>(); + let args = arg_ids + .into_iter() + .map(|arg_id| { + self.materialize_assigned_term(arg_id, stor) + .map(|id| stor.get_term(id).deserialize()) + }) + .collect::>>()?; + + let materialized = Term { functor, args }; + Some(stor.insert_term(materialized)) } } +/// Manages unification operations between a goal(node) and a clause during SLG resolution. +/// +/// You can make unification operations, [`UnifyOp`]s, by unifying the leftmost term of the goal and +/// the head of the clause. Append the operations in order to apply them to the whole goal and +/// clause. You can apply them at once via [`consume_ops`]. +/// +/// [`consume_ops`]: Self::consume_ops #[derive(Debug)] struct UnificationOperator { + /// Buffered unification operations. ops: Vec, - /// History of unification. + /// Unification history. /// - /// A pair of term ids means that `pair.0` is assiend to `pair.1`. For example, `(X, a)` means - /// `X` is assigned to `a`. + /// This is a record of `(from, to)` pairs. It means there has been unification that substitute + /// the `from` with `to`. For example, `(X, a)` means the variable `X` was substituted with `a`. record: Vec<(TermId, TermId)>, } @@ -426,6 +576,13 @@ impl UnificationOperator { self.ops.push(op); } + /// Returns + /// * `ExprId` - Operation applied `left` + /// * `ClauseId` - Operation applied `right` + /// * `Range` - A range of unification history(from/to pairs). 
You can retrieve the + /// from/to pairs via [`get_record`] + /// + /// [`get_record`]: Self::get_record #[must_use] fn consume_ops( &mut self, @@ -464,11 +621,64 @@ impl UnificationOperator { } } +#[derive(Debug, Default)] +struct NodeQueue { + inner: VecDeque, +} + +impl NodeQueue { + fn clear(&mut self) { + self.inner.clear(); + } + + fn contains(&self, node_index: &usize) -> bool { + self.inner.contains(node_index) + } + + fn push(&mut self, node_index: usize) { + if !self.contains(&node_index) { + self.inner.push_back(node_index); + } + } + + fn pop(&mut self) -> Option { + self.inner.pop_front() + } +} + #[derive(Debug, Clone)] struct Node { kind: NodeKind, - uni_delta: Range, parent: usize, + + /// A range of unification history that applied to prove this node: + /// Pairs of from([`TermId`]) -> to([`TermId`]). + /// + /// You can retrieve the from/to pairs via [`UnificationOperator::get_record`]. + uni_history: Range, + + /// Table entry owned by this node, if this node is the producer of a tabled subgoal. + table_owner: Option, + + /// Number of answers already consumed from a table entry. + table_answer_offset: usize, +} + +impl Node { + fn new(kind: NodeKind, parent: usize) -> Self { + Self { + kind, + parent, + uni_history: 0..0, + table_owner: None, + table_answer_offset: 0, + } + } + + fn with_unification_history(mut self, uni_history: Range) -> Self { + self.uni_history = uni_history; + self + } } #[derive(Debug, Clone, Copy)] @@ -480,15 +690,75 @@ enum NodeKind { Leaf(bool), } +#[derive(Debug, Default)] +pub(crate) struct TermAssignments { + /// Union-find from-to relations. + /// + /// # Examples + /// `relations[a]: a` means TermId(a) is not unified with anything. + /// `relations[v]: w` means TermId(v) is a variable and it is unified with TermId(w).
+ relations: Vec, +} + +impl TermAssignments { + pub(crate) fn find(&self, from: TermId) -> Option { + let to = *self.relations.get(from.0)?; + if from == to { + Some(to) + } else { + self.find(to) + } + } + + pub(crate) fn find_optimize(&mut self, from: TermId) -> TermId { + let new_len = from.0 + 1; + for i in self.len()..new_len { + self.relations.push(TermId(i)); + } + + let to = self.relations[from.0]; + if from == to { + to + } else { + let root = self.find_optimize(to); + self.relations[from.0] = root; + root + } + } + + fn len(&self) -> usize { + self.relations.len() + } + + fn clear(&mut self) { + self.relations.clear(); + } + + fn add(&mut self, from: TermId, to: TermId) { + let root_from = self.find_optimize(from); + let root_to = self.find_optimize(to); + self.relations[root_from.0] = root_to; + } +} + +/// Unification operation between `node expr - clause's body(expr)`. #[derive(Debug)] enum UnifyOp { + /// Unification operation that rewrites the goal expression on the query side. + /// + /// Substitutes all `from`s in the goal expression with `to`. Left { from: TermId, to: TermId }, + + /// Unification operation that rewrites the clause body on the clause side. + /// + /// Substitutes all `from`s in the clause's body with `to`.
Right { from: TermId, to: TermId }, } pub struct ProveCx<'a, T: Atom> { prover: &'a mut Prover, - clause_map: &'a ClauseMap, + clauses: &'a IndexMap, Vec>, + table_clauses: &'a IndexSet>, stor: &'a mut TermStorage, nimap: &'a mut NameIntMap, old_stor_len: TermStorageLen, @@ -496,37 +766,19 @@ pub struct ProveCx<'a, T: Atom> { } impl<'a, T: Atom> ProveCx<'a, T> { - fn new( - prover: &'a mut Prover, - clause_map: &'a ClauseMap, - stor: &'a mut TermStorage, - nimap: &'a mut NameIntMap, - old_stor_len: TermStorageLen, - old_nimap_state: NameIntMapState, - ) -> Self { - Self { - prover, - clause_map, - stor, - nimap, - old_stor_len, - old_nimap_state, - } - } - pub fn prove_next(&mut self) -> Option> { - while let Some(node_index) = self.prover.queue.pop_front() { + while let Some(node_index) = self.prover.queue.pop() { if let Some(proof_result) = self.prover - .evaluate_node(node_index, self.clause_map, self.stor) + .evaluate_node(node_index, self.clauses, self.table_clauses, self.stor) { // Returns Some(EvalView) only if the result is TRUE. 
- if proof_result { + if proof_result && self.prover.record_query_answer(self.stor) { return Some(EvalView { query_vars: &self.prover.query_vars, terms: &self.stor.terms.buf, - assignments: &self.prover.assignments, - int2name: &self.nimap.int2name, + term_assigns: &self.prover.term_assigns, + nimap: self.nimap, start: 0, end: self.prover.query_vars.len(), }); @@ -551,8 +803,8 @@ impl Drop for ProveCx<'_, T> { pub struct EvalView<'a, T> { query_vars: &'a [TermId], terms: &'a [TermElem], - assignments: &'a [usize], - int2name: &'a IndexMap, + term_assigns: &'a TermAssignments, + nimap: &'a NameIntMap, /// Inclusive start: usize, /// Exclusive @@ -576,8 +828,8 @@ impl<'a, T> Iterator for EvalView<'a, T> { Some(Assignment { buf: self.terms, from, - assignments: self.assignments, - int2name: self.int2name, + term_assigns: self.term_assigns, + nimap: self.nimap, }) } else { None @@ -601,8 +853,33 @@ impl iter::FusedIterator for EvalView<'_, T> {} pub struct Assignment<'a, T> { buf: &'a [TermElem], from: TermId, - assignments: &'a [usize], - int2name: &'a IndexMap, + term_assigns: &'a TermAssignments, + nimap: &'a NameIntMap, +} + +impl<'a, T: 'a> Assignment<'a, T> { + /// Returns left hand side variable name of the assignment. + /// + /// Note that assignment's left hand side is always variable. + pub fn get_lhs_variable(&self) -> &T { + let int = self.lhs_view().find_variable().unwrap(); + self.nimap.get_name(&int).unwrap() + } + + const fn lhs_view(&self) -> TermView<'_, Integer> { + TermView { + buf: self.buf, + id: self.from, + } + } + + const fn rhs_view(&self) -> TermDeepView<'_, Integer> { + TermDeepView { + buf: self.buf, + term_assigns: self.term_assigns, + id: self.from, + } + } } impl<'a, T: Atom + 'a> Assignment<'a, T> { @@ -610,29 +887,21 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { /// /// To create a term, this method could allocate memory for the term. 
pub fn lhs(&self) -> Term { - Self::term_view_to_term(self.lhs_view(), self.int2name) + Self::term_view_to_term(self.lhs_view(), self.nimap) } /// Creates right hand side term of the assignment. /// /// To create a term, this method could allocate memory for the term. pub fn rhs(&self) -> Term { - Self::term_deep_view_to_term(self.rhs_view(), self.int2name) + Self::term_deep_view_to_term(self.rhs_view(), self.nimap) } - /// Returns left hand side variable name of the assignment. - /// - /// Note that assignment's left hand side is always variable. - pub fn get_lhs_variable(&self) -> &T { - let int = self.lhs_view().get_contained_variable().unwrap(); - self.int2name.get(&int).unwrap() - } - - fn term_view_to_term(view: TermView<'_, Integer>, int2name: &IndexMap) -> Term { + fn term_view_to_term(view: TermView<'_, Integer>, nimap: &NameIntMap) -> Term { let functor = view.functor(); let args = view.args(); - let functor = if let Some(name) = int2name.get(functor) { + let functor = if let Some(name) = nimap.get_name(functor) { name.clone() } else { unreachable!("integer {:?} has no name mapping", functor) @@ -640,20 +909,17 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { let args = args .into_iter() - .map(|arg| Self::term_view_to_term(arg, int2name)) + .map(|arg| Self::term_view_to_term(arg, nimap)) .collect(); Term { functor, args } } - fn term_deep_view_to_term( - view: TermDeepView<'_, Integer>, - int2name: &IndexMap, - ) -> Term { + fn term_deep_view_to_term(view: TermDeepView<'_, Integer>, nimap: &NameIntMap) -> Term { let functor = view.functor(); let args = view.args(); - let functor = if let Some(name) = int2name.get(functor) { + let functor = if let Some(name) = nimap.get_name(functor) { name.clone() } else { unreachable!("integer {:?} has no name mapping", functor) @@ -661,44 +927,29 @@ impl<'a, T: Atom + 'a> Assignment<'a, T> { let args = args .into_iter() - .map(|arg| Self::term_deep_view_to_term(arg, int2name)) + .map(|arg| 
Self::term_deep_view_to_term(arg, nimap)) .collect(); Term { functor, args } } - - const fn lhs_view(&self) -> TermView<'_, Integer> { - TermView { - buf: self.buf, - id: self.from, - } - } - - const fn rhs_view(&self) -> TermDeepView<'_, Integer> { - TermDeepView { - buf: self.buf, - links: self.assignments, - id: self.from, - } - } } impl Display for Assignment<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let view = format::NamedTermView::new(self.lhs_view(), self.int2name); + let view = format::NamedTermView::new(self.lhs_view(), self.nimap); Display::fmt(&view, f)?; f.write_str(" = ")?; - let view = format::NamedTermDeepView::new(self.rhs_view(), self.int2name); + let view = format::NamedTermDeepView::new(self.rhs_view(), self.nimap); Display::fmt(&view, f) } } impl Debug for Assignment<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let lhs = format::NamedTermView::new(self.lhs_view(), self.int2name); - let rhs = format::NamedTermDeepView::new(self.rhs_view(), self.int2name); + let lhs = format::NamedTermView::new(self.lhs_view(), self.nimap); + let rhs = format::NamedTermDeepView::new(self.rhs_view(), self.nimap); f.debug_struct("Assignment") .field("lhs", &lhs) @@ -774,38 +1025,6 @@ impl TermView<'_, Integer> { false } } - - /// Returns true if this term is a variable. - /// - /// e.g. Terms like `X`, `Y` will return true. - fn is_variable(&self) -> bool { - self.arity() == 0 && self.functor().is_variable() - } - - /// Returns true if this term is a variable or contains variable in it. - /// - /// e.g. Terms like `X` of `f(X)` will return true. 
- fn contains_variable(&self) -> bool { - self.is_variable() || self.args().any(|arg| arg.contains_variable()) - } - - fn get_contained_variable(&self) -> Option { - if self.is_variable() { - Some(*self.functor()) - } else { - self.args().find_map(|arg| arg.get_contained_variable()) - } - } - - fn with_variable(&self, f: &mut F) { - if self.is_variable() { - f(self); - } else { - for arg in self.args() { - arg.with_variable(f); - } - } - } } impl TermViewMut<'_, Integer> { @@ -828,20 +1047,30 @@ impl Integer { Self(index) } - pub(crate) const fn temporary(int: u32) -> Self { - Self(int | Self::VAR_FLAG | Self::TEMPORARY_FLAG) + pub(crate) fn variable(int: u32) -> Self { + let mask = Self::VAR_FLAG; + debug_assert_eq!(int & mask, 0); + Self(int | mask) } - pub(crate) const fn is_variable(self) -> bool { - (Self::VAR_FLAG & self.0) == Self::VAR_FLAG + pub(crate) fn temporary(int: u32) -> Self { + let mask = Self::VAR_FLAG | Self::TEMPORARY_FLAG; + debug_assert_eq!(int & mask, 0); + Self(int | mask) } pub(crate) const fn is_temporary_variable(self) -> bool { - let mask: u32 = Self::VAR_FLAG | Self::TEMPORARY_FLAG; + let mask = Self::VAR_FLAG | Self::TEMPORARY_FLAG; (mask & self.0) == mask } } +impl Atom for Integer { + fn is_variable(&self) -> bool { + (Self::VAR_FLAG & self.0) == Self::VAR_FLAG + } +} + impl Debug for Integer { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mask: u32 = Self::VAR_FLAG | Self::TEMPORARY_FLAG; @@ -865,23 +1094,14 @@ impl ops::AddAssign for Integer { /// Only mapping of user-input clauses and queries are stored in this map. Auto-generated variables /// or something like that are not stored here. 
+#[derive(Debug)] pub(crate) struct NameIntMap { - pub(crate) name2int: IndexMap, - pub(crate) int2name: IndexMap, + name2int: IndexMap, + int2name: IndexMap, next_int: u32, } -impl fmt::Debug for NameIntMap { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("NameIntMap") - .field("name2int", &self.name2int) - .field("int2name", &self.int2name) - .field("next_int", &self.next_int) - .finish() - } -} - -impl NameIntMap { +impl NameIntMap { pub(crate) fn new() -> Self { Self { name2int: IndexMap::default(), @@ -890,18 +1110,8 @@ impl NameIntMap { } } - pub(crate) fn name_to_int(&mut self, name: T) -> Integer { - if let Some(int) = self.name2int.get(&name) { - *int - } else { - let int = Integer::from_value(&name, self.next_int); - - self.name2int.insert(name.clone(), int); - self.int2name.insert(int, name); - - self.next_int += 1; - int - } + pub(crate) fn get_name(&self, int: &Integer) -> Option<&T> { + self.int2name.get(int) } pub(crate) fn state(&self) -> NameIntMapState { @@ -926,6 +1136,22 @@ impl NameIntMap { } } +impl NameIntMap { + pub(crate) fn name_to_int(&mut self, name: T) -> Integer { + if let Some(int) = self.name2int.get(&name) { + *int + } else { + let int = Integer::from_value(&name, self.next_int); + + self.name2int.insert(name.clone(), int); + self.int2name.insert(int, name); + + self.next_int += 1; + int + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct NameIntMapState { name2int_len: usize, @@ -938,20 +1164,26 @@ pub(crate) mod format { pub struct NamedTermView<'a, T> { view: TermView<'a, Integer>, - int2name: &'a IndexMap, + nimap: &'a NameIntMap, } - impl<'a, T: Atom> NamedTermView<'a, T> { - pub(crate) const fn new( - view: TermView<'a, Integer>, - int2name: &'a IndexMap, - ) -> Self { - Self { view, int2name } + impl<'a, T> NamedTermView<'a, T> { + pub(crate) const fn new(view: TermView<'a, Integer>, nimap: &'a NameIntMap) -> Self { + Self { view, nimap } } + fn args<'s>(&'s self) -> impl 
Iterator> + 's { + self.view.args().map(|arg| Self { + view: arg, + nimap: self.nimap, + }) + } + } + + impl<'a, T: Atom> NamedTermView<'a, T> { pub fn is(&self, term: &Term) -> bool { let functor = self.view.functor(); - let Some(functor) = self.int2name.get(functor) else { + let Some(functor) = self.nimap.get_name(functor) else { return false; }; @@ -969,29 +1201,22 @@ pub(crate) mod format { self.args().any(|arg| arg.contains(term)) } - - fn args<'s>(&'s self) -> impl Iterator> + 's { - self.view.args().map(|arg| Self { - view: arg, - int2name: self.int2name, - }) - } } - impl<'a, T: Atom + Display> Display for NamedTermView<'a, T> { + impl<'a, T: Display> Display for NamedTermView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, int2name } = self; + let Self { view, nimap } = self; let functor = view.functor(); let args = view.args(); let num_args = args.len(); - write_int(functor, int2name, f)?; + write_int(functor, nimap, f)?; if num_args > 0 { f.write_char('(')?; for (i, arg) in args.enumerate() { - fmt::Display::fmt(&Self::new(arg, int2name), f)?; + fmt::Display::fmt(&Self::new(arg, nimap), f)?; if i + 1 < num_args { f.write_str(", ")?; } @@ -1002,22 +1227,22 @@ pub(crate) mod format { } } - impl<'a, T: Atom + Debug> Debug for NamedTermView<'a, T> { + impl<'a, T: Debug> Debug for NamedTermView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, int2name } = self; + let Self { view, nimap } = self; let functor = view.functor(); let args = view.args(); let num_args = args.len(); if num_args == 0 { - if let Some(name) = int2name.get(functor) { + if let Some(name) = nimap.get_name(functor) { fmt::Debug::fmt(name, f) } else { fmt::Debug::fmt(functor, f) } } else { - let name_str = if let Some(name) = int2name.get(functor) { + let name_str = if let Some(name) = nimap.get_name(functor) { format!("{:?}", name) } else { format!("{:?}", functor) @@ -1025,7 +1250,7 @@ pub(crate) mod format { let 
mut d = f.debug_tuple(&name_str); for arg in args { - d.field(&Self::new(arg, int2name)); + d.field(&Self::new(arg, nimap)); } d.finish() } @@ -1034,32 +1259,29 @@ pub(crate) mod format { pub(crate) struct NamedTermDeepView<'a, T> { view: TermDeepView<'a, Integer>, - int2name: &'a IndexMap, + nimap: &'a NameIntMap, } impl<'a, T> NamedTermDeepView<'a, T> { - pub(crate) const fn new( - view: TermDeepView<'a, Integer>, - int2name: &'a IndexMap, - ) -> Self { - Self { view, int2name } + pub(crate) const fn new(view: TermDeepView<'a, Integer>, nimap: &'a NameIntMap) -> Self { + Self { view, nimap } } } impl<'a, T: Display> Display for NamedTermDeepView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, int2name } = self; + let Self { view, nimap } = self; let functor = view.functor(); let args = view.args(); let num_args = args.len(); - write_int(functor, int2name, f)?; + write_int(functor, nimap, f)?; if num_args > 0 { f.write_char('(')?; for (i, arg) in args.enumerate() { - fmt::Display::fmt(&Self::new(arg, int2name), f)?; + fmt::Display::fmt(&Self::new(arg, nimap), f)?; if i + 1 < num_args { f.write_str(", ")?; } @@ -1072,20 +1294,20 @@ pub(crate) mod format { impl<'a, T: Debug> Debug for NamedTermDeepView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, int2name } = self; + let Self { view, nimap } = self; let functor = view.functor(); let args = view.args(); let num_args = args.len(); if num_args == 0 { - if let Some(name) = int2name.get(functor) { + if let Some(name) = nimap.get_name(functor) { fmt::Debug::fmt(name, f) } else { fmt::Debug::fmt(functor, f) } } else { - let name_str = if let Some(name) = int2name.get(functor) { + let name_str = if let Some(name) = nimap.get_name(functor) { format!("{:?}", name) } else { format!("{:?}", functor) @@ -1093,7 +1315,7 @@ pub(crate) mod format { let mut d = f.debug_tuple(&name_str); for arg in args { - d.field(&Self::new(arg, int2name)); + 
d.field(&Self::new(arg, nimap)); } d.finish() } @@ -1102,33 +1324,32 @@ pub(crate) mod format { pub struct NamedExprView<'a, T> { view: ExprView<'a, Integer>, - int2name: &'a IndexMap, + nimap: &'a NameIntMap, } - impl<'a, T: Atom> NamedExprView<'a, T> { - pub(crate) const fn new( - view: ExprView<'a, Integer>, - int2name: &'a IndexMap, - ) -> Self { - Self { view, int2name } + impl<'a, T> NamedExprView<'a, T> { + pub(crate) const fn new(view: ExprView<'a, Integer>, nimap: &'a NameIntMap) -> Self { + Self { view, nimap } } + } + impl<'a, T: Atom> NamedExprView<'a, T> { pub fn contains_term(&self, term: &Term) -> bool { match self.view.as_kind() { ExprKind::Term(view) => NamedTermView { view, - int2name: self.int2name, + nimap: self.nimap, } .contains(term), ExprKind::Not(view) => NamedExprView { view, - int2name: self.int2name, + nimap: self.nimap, } .contains_term(term), ExprKind::And(args) | ExprKind::Or(args) => args.into_iter().any(|view| { NamedExprView { view, - int2name: self.int2name, + nimap: self.nimap, } .contains_term(term) }), @@ -1136,26 +1357,20 @@ pub(crate) mod format { } } - impl<'a, T: Atom + Display> Display for NamedExprView<'a, T> { + impl<'a, T: Display> Display for NamedExprView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, int2name } = self; + let Self { view, nimap } = self; match view.as_kind() { - ExprKind::Term(term) => fmt::Display::fmt( - &NamedTermView { - view: term, - int2name, - }, - f, - )?, + ExprKind::Term(term) => fmt::Display::fmt(&NamedTermView { view: term, nimap }, f)?, ExprKind::Not(inner) => { f.write_str("\\+ ")?; if matches!(inner.as_kind(), ExprKind::And(_) | ExprKind::Or(_)) { f.write_char('(')?; - fmt::Display::fmt(&Self::new(inner, int2name), f)?; + fmt::Display::fmt(&Self::new(inner, nimap), f)?; f.write_char(')')?; } else { - fmt::Display::fmt(&Self::new(inner, int2name), f)?; + fmt::Display::fmt(&Self::new(inner, nimap), f)?; } } ExprKind::And(args) => { @@ -1163,10 
+1378,10 @@ pub(crate) mod format { for (i, arg) in args.enumerate() { if matches!(arg.as_kind(), ExprKind::Or(_)) { f.write_char('(')?; - fmt::Display::fmt(&Self::new(arg, int2name), f)?; + fmt::Display::fmt(&Self::new(arg, nimap), f)?; f.write_char(')')?; } else { - fmt::Display::fmt(&Self::new(arg, int2name), f)?; + fmt::Display::fmt(&Self::new(arg, nimap), f)?; } if i + 1 < num_args { f.write_str(", ")?; @@ -1176,7 +1391,7 @@ pub(crate) mod format { ExprKind::Or(args) => { let num_args = args.len(); for (i, arg) in args.enumerate() { - fmt::Display::fmt(&Self::new(arg, int2name), f)?; + fmt::Display::fmt(&Self::new(arg, nimap), f)?; if i + 1 < num_args { f.write_str("; ")?; } @@ -1187,27 +1402,27 @@ pub(crate) mod format { } } - impl<'a, T: Atom + Debug> Debug for NamedExprView<'a, T> { + impl<'a, T: Debug> Debug for NamedExprView<'a, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let Self { view, int2name } = self; + let Self { view, nimap } = self; match view.as_kind() { - ExprKind::Term(term) => fmt::Debug::fmt(&NamedTermView::new(term, int2name), f), + ExprKind::Term(term) => fmt::Debug::fmt(&NamedTermView::new(term, nimap), f), ExprKind::Not(inner) => f .debug_tuple("Not") - .field(&NamedExprView::new(inner, int2name)) + .field(&NamedExprView::new(inner, nimap)) .finish(), ExprKind::And(args) => { let mut d = f.debug_tuple("And"); for arg in args { - d.field(&NamedExprView::new(arg, int2name)); + d.field(&NamedExprView::new(arg, nimap)); } d.finish() } ExprKind::Or(args) => { let mut d = f.debug_tuple("Or"); for arg in args { - d.field(&NamedExprView::new(arg, int2name)); + d.field(&NamedExprView::new(arg, nimap)); } d.finish() } @@ -1217,10 +1432,10 @@ pub(crate) mod format { fn write_int( int: &Integer, - map: &IndexMap, + nimap: &NameIntMap, f: &mut fmt::Formatter<'_>, ) -> fmt::Result { - if let Some(name) = map.get(int) { + if let Some(name) = nimap.get_name(int) { fmt::Display::fmt(name, f) } else { fmt::Debug::fmt(int, f) diff 
--git a/crates/logic-eval/src/prove/repr.rs b/crates/logic-eval/src/prove/repr.rs index 675a0b5..b681b26 100644 --- a/crates/logic-eval/src/prove/repr.rs +++ b/crates/logic-eval/src/prove/repr.rs @@ -1,6 +1,5 @@ -use crate::{Expr, PassThroughState, Predicate, Term}; +use crate::{prove::prover::TermAssignments, Atom, Expr, PassThroughIndexMap, Predicate, Term}; use fxhash::FxHasher; -use indexmap::IndexMap; use std::{ hash::{Hash, Hasher}, iter, ops, @@ -18,12 +17,6 @@ pub(crate) struct TermStorage { pub(crate) terms: UniqueTermArray, } -#[derive(Debug, Clone, PartialEq, Eq)] -pub(crate) struct TermStorageLen { - expr_len: usize, - term_len: TermArrayLen, -} - impl TermStorage { pub(crate) fn new() -> Self { Self { @@ -71,6 +64,12 @@ impl TermStorage { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct TermStorageLen { + expr_len: usize, + term_len: TermArrayLen, +} + #[derive(Debug)] pub(crate) struct ExprArray { buf: Vec, @@ -114,31 +113,31 @@ impl ExprArray { self.buf.push(elem); id } - Expr::Not(expr) => { + Expr::Not(arg) => { let idx = self.reserve(1); - let inner_id = self.insert(*expr, term_arr); + let inner_id = self.insert(*arg, term_arr); self.buf[idx] = ExprElem::Not(inner_id); ExprId(idx) } - Expr::And(exprs) => { - let num_args = exprs.len(); + Expr::And(args) => { + let num_args = args.len(); let idx = self.reserve(1 + num_args); self.buf[idx] = ExprElem::And { len: num_args }; - for (i, expr) in exprs.into_iter().enumerate() { - let arg_id = self.insert(expr, term_arr); + for (i, arg) in args.into_iter().enumerate() { + let arg_id = self.insert(arg, term_arr); self.buf[idx + 1 + i] = ExprElem::Expr(arg_id); } ExprId(idx) } - Expr::Or(exprs) => { - let num_args = exprs.len(); + Expr::Or(args) => { + let num_args = args.len(); let idx = self.reserve(1 + num_args); self.buf[idx] = ExprElem::Or { len: num_args }; - for (i, expr) in exprs.into_iter().enumerate() { - let arg_id = self.insert(expr, term_arr); + for (i, arg) in 
args.into_iter().enumerate() { + let arg_id = self.insert(arg, term_arr); self.buf[idx + 1 + i] = ExprElem::Expr(arg_id); } @@ -355,6 +354,44 @@ impl<'a, T> ExprViewMut<'a, T> { self.id } + pub(crate) fn with_terminal(&mut self, mut f: F) + where + F: FnMut(&mut UniqueTermArray, TermId), + { + fn helper<'a, T, F>(this: &mut ExprViewMut<'a, T>, f: &mut F) + where + F: FnMut(&mut UniqueTermArray, TermId), + { + this.find_then_move(); + + match this.exprs[this.id] { + ExprElem::Term(term) => { + TermViewMut { + arr: this.terms, + id: term, + } + .with_terminal(f); + } + ExprElem::Not(inner) => { + let org = this.id; + this.id = inner; + helper(this, f); + this.id = org; + } + ExprElem::And { len } | ExprElem::Or { len } => { + let org = this.id; + for _ in 0..len { + this.id += 1; + helper(this, f); + } + this.id = org; + } + ExprElem::Expr(_) => unreachable!(), + } + } + helper(self, &mut f) + } + /// Finds the destination of jump (Elem::Expr) chain then moves this view to the final /// expression. fn find_then_move(&mut self) { @@ -370,38 +407,6 @@ impl<'a, T> ExprViewMut<'a, T> { src } } - - pub(crate) fn with_terminal(&mut self, f: &mut F) - where - F: FnMut(&mut UniqueTermArray, TermId), - { - self.find_then_move(); - - match self.exprs[self.id] { - ExprElem::Term(term) => { - TermViewMut { - arr: self.terms, - id: term, - } - .with_terminal(f); - } - ExprElem::Not(inner) => { - let org = self.id; - self.id = inner; - self.with_terminal(f); - self.id = org; - } - ExprElem::And { len } | ExprElem::Or { len } => { - let org = self.id; - for _ in 0..len { - self.id += 1; - self.with_terminal(f); - } - self.id = org; - } - ExprElem::Expr(_) => unreachable!(), - } - } } #[derive(Debug, PartialEq, Eq)] @@ -616,7 +621,7 @@ impl<'a, T: Clone> ExprViewMut<'a, T> { } } -impl<'a, T: Clone + Eq + Hash> ExprViewMut<'a, T> { +impl<'a, T: Atom> ExprViewMut<'a, T> { /// If this expression contains `from`, then replaces them to `to` in a clone-on-write way. 
/// /// If the replacement took place, then a new expression is created then this view becomes to @@ -701,14 +706,14 @@ pub(crate) struct UniqueTermArray { /// /// You are encouraged to call two methods below to access this field, [`Self::add_mapping`] and /// [`Self::get_similar`], which hide the problem. - pub(crate) map: IndexMap, PassThroughState>, + pub(crate) map: PassThroughIndexMap>, } impl UniqueTermArray { fn new() -> Self { Self { buf: Vec::new(), - map: IndexMap::default(), + map: PassThroughIndexMap::default(), } } @@ -772,6 +777,28 @@ impl UniqueTermArray { } } + /// If there's no structurally identical term in the array, then returns Ok with the `new_id`. + /// Otherwise, returns Err with the existing id. + /// + /// If result is Ok, then registers the `new_id` in the map. + fn try_register_unique_term(&mut self, new_id: TermId) -> Result { + let hash = buf_term_hash(&self.buf, new_id); + let dup = 'search: { + for existing_id in Self::get_similar(self.map.get_mut(&hash), &self.buf, hash) { + if existing_id != new_id && structually_eq(&self.buf, existing_id, new_id) { + break 'search Some(existing_id); + } + } + None + }; + if let Some(existing_id) = dup { + return Err(existing_id); + } + + self.add_mapping(hash, new_id); + Ok(new_id) + } + pub(crate) fn insert(&mut self, term: Term) -> TermId { // Checks existence. let hash = term_hash(&term); @@ -948,6 +975,74 @@ impl TermView<'_, T> { arity: self.arity(), } } + + pub(crate) fn deserialize(self) -> Term { + Term { + functor: self.functor().clone(), + args: self.args().map(Self::deserialize).collect(), + } + } +} + +impl TermView<'_, T> { + /// Returns true if this term is a variable. + /// + /// e.g. Terms like `X`, `Y` will return true. + pub(crate) fn is_variable(&self) -> bool { + let is_var = self.functor().is_variable(); + + #[cfg(debug_assertions)] + if is_var { + assert_eq!(self.arity(), 0); + } + + is_var + } + + /// Returns true if this term is a variable or contains variable in it.
+ /// + /// e.g. Terms like `X` or `f(X)` will return true. + pub(crate) fn contains_variable(&self) -> bool { + self.is_variable() || self.args().any(|arg| arg.contains_variable()) + } + + pub(crate) fn with_variable(&self, mut f: F) { + fn helper<'a, T, F>(view: &TermView<'a, T>, f: &mut F) + where + T: Atom, + F: FnMut(&TermView<'a, T>), + { + if view.is_variable() { + f(view); + } else { + for arg in view.args() { + helper(&arg, f); + } + } + } + helper(self, &mut f) + } + + pub(crate) fn collect_variables(&self) -> Vec { + let mut vars = Vec::new(); + self.with_variable(|var| { + if !vars.contains(&var.id) { + vars.push(var.id); + } + }); + vars + } +} + +impl TermView<'_, T> { + /// Returns the first variable functor in this term. + pub(crate) fn find_variable(&self) -> Option { + if self.is_variable() { + Some(*self.functor()) + } else { + self.args().find_map(|arg| arg.find_variable()) + } + } } #[derive(Clone)] @@ -997,7 +1092,7 @@ impl iter::FusedIterator for TermViewArgs<'_, T> {} #[derive(Debug, Clone)] pub struct TermDeepView<'a, T> { pub(crate) buf: &'a [TermElem], - pub(crate) links: &'a [usize], + pub(crate) term_assigns: &'a TermAssignments, pub(crate) id: TermId, } @@ -1027,26 +1122,18 @@ impl<'a, T> TermDeepView<'a, T> { let end = start + view.arity() as usize; TermDeepViewArgs { buf: view.buf, - links: view.links, + term_assigns: view.term_assigns, start, end, } } pub(crate) fn jump(&self) -> Self { - let mut i = self.id.0; - - while let Some(next) = self.links.get(i) { - if i == *next { - break; - } - i = *next; - } - + let root = self.term_assigns.find(self.id).unwrap_or(self.id); Self { buf: self.buf, - links: self.links, - id: TermId(i), + term_assigns: self.term_assigns, + id: root, } } } @@ -1054,7 +1141,7 @@ impl<'a, T> TermDeepView<'a, T> { #[derive(Clone)] pub(crate) struct TermDeepViewArgs<'a, T> { buf: &'a [TermElem], - links: &'a [usize], + term_assigns: &'a TermAssignments, /// Inclusive start: TermId, /// Exclusive @@ -1078,7
+1165,7 @@ impl<'a, T> Iterator for TermDeepViewArgs<'a, T> { self.start += 1; Some(TermDeepView { buf: self.buf, - links: self.links, + term_assigns: self.term_assigns, id, }) } else { @@ -1106,6 +1193,13 @@ pub(crate) struct TermViewMut<'a, T> { } impl<'a, T> TermViewMut<'a, T> { + pub(crate) fn as_view(&self) -> TermView<'_, T> { + TermView { + buf: &self.arr.buf, + id: self.id, + } + } + pub(crate) const fn id(&self) -> TermId { self.id } @@ -1124,30 +1218,52 @@ impl<'a, T> TermViewMut<'a, T> { n } - pub(crate) fn with_terminal(&mut self, f: &mut F) + pub(crate) fn with_terminal(&mut self, mut f: F) where F: FnMut(&mut UniqueTermArray, TermId), { - let arity = self.arity(); - if arity == 0 { - f(self.arr, self.id); - } else { - let org = self.id.0; + fn helper<'a, T, F>(this: &mut TermViewMut<'a, T>, f: &mut F) + where + F: FnMut(&mut UniqueTermArray, TermId), + { + let arity = this.arity(); + if arity == 0 { + f(this.arr, this.id); + } else { + let org = this.id.0; - for i in 0..arity as usize { - let TermElem::Arg(arg_id) = self.arr.buf[org + 2 + i] else { - unreachable!() - }; - self.id = arg_id; - self.with_terminal(f); - } + for i in 0..arity as usize { + let TermElem::Arg(arg_id) = this.arr.buf[org + 2 + i] else { + unreachable!() + }; + this.id = arg_id; + helper(this, f); + } - self.id = TermId(org); + this.id = TermId(org); + } } + helper(self, &mut f) } } -impl TermViewMut<'_, T> { +impl TermViewMut<'_, T> { + /// Applies the given function to all functors in this term, then returns true if any of + /// functors has been replaced. + /// + /// Original term never changes, instead this method makes a new term when required. If new term + /// has been generated, this view becomes to point the new term. 
+ pub(crate) fn replace_with Option>(&mut self, mut map_func: F) -> bool { + let mut buf_off = self.arr.buf.len(); + let mut early_exit = |_: TermId| None; + if let Some(new_id) = self.replace_inner(&mut buf_off, &mut map_func, &mut early_exit) { + self.id = new_id; + true + } else { + false + } + } + /// If this term is `from` or contains `from`, then replaces them to `to` in a clone-on-write /// way. /// @@ -1159,7 +1275,9 @@ impl TermViewMut<'_, T> { } let mut buf_off = self.arr.buf.len(); - if let Some(new_id) = self._replace(from, to, &mut buf_off) { + let mut map_func = |_: &T| None; + let mut early_exit = |id: TermId| (id == from).then_some(to); + if let Some(new_id) = self.replace_inner(&mut buf_off, &mut map_func, &mut early_exit) { self.id = new_id; true } else { @@ -1167,9 +1285,24 @@ impl TermViewMut<'_, T> { } } - fn _replace(&mut self, from: TermId, to: TermId, buf_off: &mut usize) -> Option { - if self.id == from { - return Some(to); + /// Shared clone-on-write implementation used by [`Self::replace`] and [`Self::replace_with`]. + /// + /// * `map_func` - Optionally replaces the functor of the current term. + /// * `early_exit` - Optionally replaces the current term by id, short-circuiting before any + /// cloning. + fn replace_inner( + &mut self, + buf_off: &mut usize, + map_func: &mut MapFn, + early_exit: &mut MapId, + ) -> Option + where + MapFn: FnMut(&T) -> Option, + MapId: FnMut(TermId) -> Option, + { + // Short-circuit: return a direct replacement for this term without cloning. + if let Some(id) = early_exit(self.id) { + return Some(id); } // Term id & term space for this view. @@ -1177,15 +1310,14 @@ impl TermViewMut<'_, T> { // Term id & term space for a new term. let new = *buf_off; - // Reserves buffer space for a new term corresponding this view. 
- let TermElem::Arity(arity) = self.arr.buf[cur + 1] else { - unreachable!() - }; - *buf_off = new + 2 + arity as usize; - let org_buf_len = self.arr.buf.len(); - if self.arr.buf.len() < *buf_off { - self.arr.buf.resize_with(*buf_off, || TermElem::dummy()); - } + // Reserves buffer space for the size of this term if required. + let (arity, org_buf_len, new_end) = self.ensure_space(new); + + // Moves the offset for nested inner terms. + *buf_off = new_end; + + let new_functor = map_func(self.functor()); + let is_func_replaced = new_functor.is_some(); // Tries to replace the arguments. let mut is_arg_replaced = false; @@ -1195,7 +1327,8 @@ impl TermViewMut<'_, T> { }; self.id = arg_id; - self.arr.buf[new + 2 + i] = if let Some(new_arg_id) = self._replace(from, to, buf_off) { + let new_arg_id = self.replace_inner(buf_off, map_func, early_exit); + self.arr.buf[new + 2 + i] = if let Some(new_arg_id) = new_arg_id { is_arg_replaced = true; TermElem::Arg(new_arg_id) } else { @@ -1204,26 +1337,51 @@ impl TermViewMut<'_, T> { } self.id = TermId(cur); - if is_arg_replaced { - // Sets the functor and arity at the new space. - let TermElem::Functor(functor) = &self.arr.buf[cur] else { - unreachable!() + if is_func_replaced || is_arg_replaced { + // Sets the functor/arity at the new space. + self.arr.buf[new] = match new_functor { + Some(f) => TermElem::Functor(f), + None => TermElem::Functor(self.functor().clone()), }; - self.arr.buf[new] = TermElem::Functor(functor.clone()); self.arr.buf[new + 1] = TermElem::Arity(arity); - let new_id = TermId(new); - // New mapping. - let hash = buf_term_hash(&self.arr.buf, new_id); - self.arr.add_mapping(hash, new_id); - - Some(new_id) + // Dedup: Reuse an existing structually identical term if present, otherwise register + // the new one. 
+ match self.arr.try_register_unique_term(TermId(new)) { + Ok(id) => Some(id), + Err(existing_id) => { + self.arr.buf.truncate(org_buf_len); // Discards the buffer change + Some(existing_id) + } + } } else { - // Discards the buffer change. - self.arr.buf.truncate(org_buf_len); + self.arr.buf.truncate(org_buf_len); // Discards the buffer change None } } + + /// Reserves term space as much as this term in the term array if required. + /// + /// Returns + /// - arity + /// - old term array length + /// - `off + 2 + arity`, which means the end index of the reserved term space. + fn ensure_space(&mut self, off: usize) -> (u32, usize, usize) { + let cur = self.id.0; + let TermElem::Arity(arity) = self.arr.buf[cur + 1] else { + unreachable!() + }; + + let term_space = 2 /* functor + arity itself */ + arity as usize; + let end = off + term_space; + + let old_len = self.arr.buf.len(); + if old_len < end { + self.arr.buf.resize_with(end, || TermElem::dummy()); + } + + (arity, old_len, end) + } } /// Element representing a part of a term in a unified buffer. @@ -1282,6 +1440,40 @@ fn term_hash(term: &Term) -> u64 { } } +/// Returns true if the two terms at `a` and `b` are structurally identical. +fn structually_eq(buf: &[TermElem], a: TermId, b: TermId) -> bool { + if a == b { + return true; + } + let TermElem::Functor(fa) = &buf[a.0] else { + return false; + }; + let TermElem::Functor(fb) = &buf[b.0] else { + return false; + }; + if fa != fb { + return false; + } + let TermElem::Arity(na) = buf[a.0 + 1] else { + return false; + }; + let TermElem::Arity(nb) = buf[b.0 + 1] else { + return false; + }; + if na != nb { + return false; + } + (0..na as usize).all(|i| { + let TermElem::Arg(arg_a) = buf[a.0 + 2 + i] else { + return false; + }; + let TermElem::Arg(arg_b) = buf[b.0 + 2 + i] else { + return false; + }; + structually_eq(buf, arg_a, arg_b) + }) +} + /// Generates the same hash value as what [`term_hash`] generates. 
fn buf_term_hash(buf: &[TermElem], id: TermId) -> u64 { // A hasher with fixed keys @@ -1320,11 +1512,6 @@ mod tests { use any_intern::DroplessInterner; #[test] - fn test_expr_array() { - test_expr_array_replace_term(); - test_expr_array_replace_expr(); - } - fn test_expr_array_replace_term() { let mut buf = TermStorage::new(); let interner = DroplessInterner::default(); @@ -1378,6 +1565,7 @@ mod tests { assert_eq!(buf.exprs.buf, expected_buf); } + #[test] fn test_expr_array_replace_expr() { let mut buf = TermStorage::new(); let interner = DroplessInterner::default(); @@ -1423,11 +1611,6 @@ mod tests { } #[test] - fn test_term_array() { - test_term_array_replace(); - test_recursive_term(); - } - fn test_term_array_replace() { let mut arr = UniqueTermArray::new(); let interner = DroplessInterner::default(); @@ -1484,6 +1667,65 @@ mod tests { assert_eq!(arr.buf, expected_buf); } + #[test] + #[rustfmt::skip] + fn test_term_array_replace_with() { + let mut arr = UniqueTermArray::new(); + let interner = DroplessInterner::default(); + + let id_f = insert_term(&mut arr, &interner, "f($X, $Y, $X)"); + + let mut expected_buf: Vec>> = vec![ + /* 0 */ TermElem::Functor(Name::with_intern("f", &interner)), + /* 1 */ TermElem::Arity(3), + /* 2 */ TermElem::Arg(TermId(5)), + /* 3 */ TermElem::Arg(TermId(7)), + /* 4 */ TermElem::Arg(TermId(5)), // X appears twice, but stored once + /* 5 */ TermElem::Functor(Name::with_intern("$X", &interner)), + /* 6 */ TermElem::Arity(0), + /* 7 */ TermElem::Functor(Name::with_intern("$Y", &interner)), + /* 8 */ TermElem::Arity(0), + ]; + + assert_eq!(arr.buf, expected_buf); + assert_eq!(id_f, TermId(0)); + + // === Replace === + + let mut view = arr.get_mut(id_f); + let replaced = view.replace_with(|functor| { + let interned = match functor.as_ref() { + "$X" => interner.intern_formatted_str(&0, 1).unwrap(), + "$Y" => interner.intern_formatted_str(&1, 1).unwrap(), + _ => return None, + }; + Some(Name::new(interned)) + }); + assert!(replaced); + + 
let clone_on_replace: Vec>> = vec![ + /* 9 */ TermElem::Functor(Name::with_intern("f", &interner)), + /* 10 */ TermElem::Arity(3), + /* 11 */ TermElem::Arg(TermId(14)), + /* 12 */ TermElem::Arg(TermId(16)), + /* 13 */ TermElem::Arg(TermId(14)), // reuses first "0" instead of a duplicate + /* 14 */ TermElem::Functor(Name::with_intern("0", &interner)), + /* 15 */ TermElem::Arity(0), + /* 16 */ TermElem::Functor(Name::with_intern("1", &interner)), + /* 17 */ TermElem::Arity(0), + ]; + expected_buf.extend(clone_on_replace); + + // The view now points to the newly created f(0, 1, 0). + assert_eq!(view.id(), TermId(9)); + assert_eq!(arr.buf, expected_buf); + // The original f(X, Y, X) at id=0 is untouched. + assert_eq!(arr.buf[0], TermElem::Functor(Name::with_intern("f", &interner))); + assert_eq!(arr.buf[2], TermElem::Arg(TermId(5))); + assert_eq!(arr.buf[4], TermElem::Arg(TermId(5))); // both X args still point to original X + } + + #[test] fn test_recursive_term() { let mut arr = UniqueTermArray::new(); let interner = DroplessInterner::default(); diff --git a/crates/logic-eval/src/prove/table.rs b/crates/logic-eval/src/prove/table.rs new file mode 100644 index 0000000..e615988 --- /dev/null +++ b/crates/logic-eval/src/prove/table.rs @@ -0,0 +1,215 @@ +use super::{ + canonical::CanonicalTermId, + prover::Integer, + repr::{TermId, TermView}, +}; +use crate::{prove::prover::TermAssignments, Map}; +use core::ops::{Index, IndexMut}; + +#[derive(Debug, Default)] +pub(crate) struct Table { + indices: Map, + entries: Vec, +} + +impl Table { + pub(crate) fn clear(&mut self) { + self.indices.clear(); + self.entries.clear(); + } + + pub(crate) fn get_mut( + &mut self, + key: &CanonicalTermId, + ) -> Option<(TableIndex, &mut TableEntry)> { + self.indices + .get(key) + .map(|&i| (TableIndex(i), &mut self.entries[i])) + } + + pub(crate) fn register(&mut self, key: CanonicalTermId, entry: TableEntry) -> TableIndex { + if let Some(table_index) = self.indices.get(&key) { + return 
TableIndex(*table_index); + } + + let table_index = self.entries.len(); + self.indices.insert(key, table_index); + self.entries.push(entry); + TableIndex(table_index) + } +} + +impl Index for Table { + type Output = TableEntry; + + fn index(&self, index: TableIndex) -> &Self::Output { + &self.entries[index.0] + } +} + +impl IndexMut for Table { + fn index_mut(&mut self, index: TableIndex) -> &mut Self::Output { + &mut self.entries[index.0] + } +} + +#[derive(Debug)] +pub(crate) struct TableEntry { + /// Non-empty assignment record for variables in a node. + /// + /// All [`SeenAssignments`] have the same length of answers. Use the same index to get a set of + /// answers. + /// + /// # Examples + /// + /// If we have 3 answers like below, the answers for (X, Y) are only (a, i), (b, j) or (c, k). + /// Other combinations are invalid. + /// X = a or b or c + /// Y = i or j or k + seen: AnswerMatrix, + + /// Consumers(nodes) will be notified when their entry has updated. + consumers: Vec, +} + +impl TableEntry { + /// Making an entry can be rejected when + /// - `view` is just a variable, which doesn't make sense for tabling. f(X) should be given for + /// example. + /// - `view` doesn't contain any variables in it. It doesn't need the tabling. + pub(crate) fn from_term_view(view: &TermView<'_, Integer>) -> Option { + if view.is_variable() || !view.contains_variable() { + return None; + } + + let mut vars = Vec::new(); + for arg in view.args() { + arg.with_variable(|var| { + if !vars.contains(&var.id) { + vars.push(var.id); + } + }); + } + + Some(Self { + seen: AnswerMatrix::with_variables(vars), + consumers: Vec::new(), + }) + } + + /// See [`AnswerMatrix::update`]. 
+ pub(crate) fn update_answer(&mut self, term_assigns: &TermAssignments) { + self.seen.update(term_assigns); + } + + pub(crate) fn has_answer(&self, term_assigns: &TermAssignments) -> bool { + self.seen.has_answer(term_assigns) + } + + pub(crate) fn consumer_nodes(&self) -> impl Iterator + '_ { + self.consumers.iter().map(|c| c.node_index) + } + + pub(crate) fn variables(&self) -> &[TermId] { + self.seen.column(0) + } + + /// An empty slice is returned when the `answer_index` is out of bounds. + pub(crate) fn answers(&self, answer_index: usize) -> &[TermId] { + let col = answer_index + 1; + self.seen.column(col) + } + + pub(crate) fn register_consumer(&mut self, node_index: usize) { + if self + .consumers + .iter() + .all(|consumer| consumer.node_index != node_index) + { + self.consumers.push(Consumer { node_index }); + } + } +} + +/// A non-empty variable-answer relation. +/// +/// unique var | answer1 | answer2 | answer3 | +/// :----------: | :-----: | :-----: | :-----: | +/// X | a | x | i | +/// Y | b | y | j | +/// W | c | z | k | +/// Z | d | w | l | +#[derive(Debug)] +pub(crate) struct AnswerMatrix { + /// Column-wise elements, e.g. X, Y, W, Z, a, b, c, d, ... + elems: Vec, + rows: usize, + // For double check + cols: usize, +} + +impl AnswerMatrix { + fn with_variables(vars: Vec) -> Self { + let rows = vars.len(); + debug_assert!(rows > 0); + + for i in 0..rows { + for j in i + 1..rows { + debug_assert_ne!(vars[i], vars[j]); + } + } + + Self { + elems: vars, + rows, + cols: 1, + } + } + + /// An empty slice is returned when the `col` is out of bounds. + fn column(&self, col: usize) -> &[TermId] { + let start = col * self.rows; + let end = start + self.rows; + + if end <= self.elems.len() { + &self.elems[start..end] + } else { + &[] + } + } + + /// This method assumes that the `term_assigns` has concrete answers(atoms, not variables) for + /// variables of this entry.
+ fn update(&mut self, term_assigns: &TermAssignments) { + self.elems.reserve_exact(self.rows); + for r in 0..self.rows { + let var = self.elems[r]; + let answer = term_assigns.find(var).unwrap(); + self.elems.push(answer); + } + self.cols += 1; + } + + fn has_answer(&self, term_assigns: &TermAssignments) -> bool { + let vars = self.column(0); + for col_idx in 1..self.cols { + let answers = self.column(col_idx); + if vars + .iter() + .zip(answers) + .all(|(var, answer)| term_assigns.find(*var) == Some(*answer)) + { + return true; + } + } + false + } +} + +#[derive(Debug)] +struct Consumer { + node_index: usize, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub(crate) struct TableIndex(usize); From 8e4df388a14bc48590a6ebd1aac9ac3171d51118 Mon Sep 17 00:00:00 2001 From: ecoricemon Date: Wed, 25 Mar 2026 14:52:43 +0900 Subject: [PATCH 2/2] clippy --- crates/logic-eval/src/prove/db.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/crates/logic-eval/src/prove/db.rs b/crates/logic-eval/src/prove/db.rs index b92762c..6c7159c 100644 --- a/crates/logic-eval/src/prove/db.rs +++ b/crates/logic-eval/src/prove/db.rs @@ -381,15 +381,13 @@ impl DuplicateClauseChecker { let canonical_clause = clause.map(&mut |t| { if !t.is_variable() { t + } else if let Some(found) = self.vars.iter().find(|&&var| var == t) { + *found } else { - if let Some(found) = self.vars.iter().find(|&&var| var == t) { - *found - } else { - let next_int = self.vars.len() as u32; - let int = Integer::variable(next_int); - self.vars.push(int); - int - } + let next_int = self.vars.len() as u32; + let int = Integer::variable(next_int); + self.vars.push(int); + int } }); let is_new = self.seen.insert(canonical_clause);