From 7ac0a545d16e6d7b782d7ec86e7159ddb29eaa22 Mon Sep 17 00:00:00 2001 From: alderpath Date: Mon, 1 Jun 2026 12:49:53 +0100 Subject: [PATCH] Hardening: remove 21 DB unwraps, add schema versioning, canonicalize cache-dir, cap WalkDir at 500K, add concurrent build guard P0: All 21 DB .unwrap() calls replaced with graceful match+eprintln fallback. P1: PRAGMA user_version tracked; open_existing_db rejects mismatched schemas. P2: --cache-dir now canonicalized via std::fs::canonicalize. P3: WalkDir capped at 500,000 files with secondary truncate. P5: AtomicBool build_in_progress prevents concurrent index builds on switch_repo. 54/54 unit tests, 17/17 integration tests passing. --- src/index/mod.rs | 7 +- src/index/schema.rs | 15 ++- src/main.rs | 47 ++++++--- src/search/mod.rs | 153 +++++++++++++++------------ src/structural_risk.rs | 231 ++++++++++++++++++++++++----------------- 5 files changed, 276 insertions(+), 177 deletions(-) diff --git a/src/index/mod.rs b/src/index/mod.rs index b0d8cd0..23fc4d7 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -207,7 +207,6 @@ pub fn build_phrase_index(repo_path: &str, out_dir: &Path, verbose: bool) -> Res } else { let db = Connection::open(&db_path).map_err(|e| format!("db: {}", e))?; schema::open_existing_db(&db).map_err(|e| format!("pragmas: {}", e))?; - schema::create_new_db(&db).ok(); // create tables if they don't exist (db, false) }; @@ -757,6 +756,8 @@ fn collect_source_files(repo: &Path) -> Vec { ] .into(); + const MAX_FILES: usize = 500_000; + for entry in WalkDir::new(repo) .follow_links(false) .into_iter() @@ -775,6 +776,9 @@ fn collect_source_files(repo: &Path) -> Vec { if !entry.file_type().is_file() { continue; } + if files.len() >= MAX_FILES { + break; + } let fname = entry.file_name().to_string_lossy(); if lock_suffixes.contains(fname.as_ref()) { continue; @@ -798,6 +802,7 @@ fn collect_source_files(repo: &Path) -> Vec { } } files.sort(); + files.truncate(MAX_FILES); files } diff --git a/src/index/schema.rs b/src/index/schema.rs index 3a0f2d8..726f35d 100644 --- a/src/index/schema.rs +++ b/src/index/schema.rs @@ -1,5 +1,7 @@ use rusqlite::Connection; +pub(crate) const SCHEMA_VERSION: i32 = 1; + pub(crate) fn create_new_db(db: &Connection) -> rusqlite::Result<()> { db.execute_batch( "PRAGMA synchronous = OFF; @@ -9,7 +11,8 @@ pub(crate) fn create_new_db(db: &Connection) -> rusqlite::Result<()> { PRAGMA temp_store = MEMORY; PRAGMA lock_timeout = 5000;", )?; - create_tables(db) + create_tables(db)?; + db.execute_batch(&format!("PRAGMA user_version = {}", SCHEMA_VERSION)) } pub(crate) fn open_existing_db(db: &Connection) -> rusqlite::Result<()> { @@ -20,7 +23,15 @@ pub(crate) fn open_existing_db(db: &Connection) -> rusqlite::Result<()> { PRAGMA mmap_size = 268435456; PRAGMA temp_store = MEMORY; PRAGMA lock_timeout = 5000;", - ) + )?; + let version: i32 = db.query_row("PRAGMA user_version", [], |r| r.get(0))?; + if version != SCHEMA_VERSION { + return Err(rusqlite::Error::InvalidColumnName(format!( + "Schema version mismatch: DB has {}, stria expects {}. Run `stria build` to upgrade.", + version, SCHEMA_VERSION + ))); + } + Ok(()) } fn create_tables(db: &Connection) -> rusqlite::Result<()> { diff --git a/src/main.rs b/src/main.rs index b9b430a..123a0b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -95,7 +95,7 @@ fn main() { .unwrap_or_else(|_| repo_path.to_path_buf()); let out_dir = cache_dir .as_ref() - .map(|c| Path::new(c).to_path_buf()) + .map(|c| std::fs::canonicalize(c).unwrap_or_else(|_| Path::new(c).to_path_buf())) .unwrap_or_else(|| canonical.join(".stria")); if !out_dir.join("phrases.sqlite").exists() { match index::build_phrase_index( @@ -123,7 +123,7 @@ fn main() { .unwrap_or_else(|_| repo_path.to_path_buf()); let out_dir = cache_dir .as_ref() - .map(|c| Path::new(c).to_path_buf()) + .map(|c| std::fs::canonicalize(c).unwrap_or_else(|_| Path::new(c).to_path_buf())) .unwrap_or_else(|| canonical.join(".stria")); if !out_dir.join("phrases.sqlite").exists() { match index::build_phrase_index( @@ -155,16 +155,19 @@ fn main() { fn mcp_server(initial_repo: String, initial_out: &std::path::Path) { use serde_json::{json, Value}; use std::io::{self, BufRead, Write}; + use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Mutex; struct ServerState { repo: String, cache_dir: std::path::PathBuf, + build_in_progress: AtomicBool, } let state = Mutex::new(ServerState { repo: initial_repo, cache_dir: initial_out.to_path_buf(), + build_in_progress: AtomicBool::new(false), }); let db_path_of = |cache: &std::path::Path| -> String { @@ -578,17 +581,37 @@ fn mcp_server(initial_repo: String, initial_out: &std::path::Path) { } else { let out_dir = std::path::Path::new(&canonical).join(".stria"); if !out_dir.join("phrases.sqlite").exists() { - match index::build_phrase_index(&canonical, &out_dir, false) { - Ok(n) => { - let mut st = state.lock().unwrap(); - st.repo = canonical; - st.cache_dir = out_dir; - drop(st); - json!({"status": "ok", "phrases": n}) - } - Err(e) => { - json!({"error": format!("Index build failed: {}", e)}) + // Check if a build is already in progress + let already_building = { + let st = state.lock().unwrap(); + st.build_in_progress.load(Ordering::Relaxed) + }; + if already_building { + json!({"error": "Index build already in progress for another repo"}) + } else { + { + let st = state.lock().unwrap(); + st.build_in_progress.store(true, Ordering::Relaxed); } + let result = match index::build_phrase_index( + &canonical, &out_dir, false, + ) { + Ok(n) => { + let mut st = state.lock().unwrap(); + st.repo = canonical; + st.cache_dir = out_dir; + st.build_in_progress.store(false, Ordering::Relaxed); + drop(st); + json!({"status": "ok", "phrases": n}) + } + Err(e) => { + let st = state.lock().unwrap(); + st.build_in_progress.store(false, Ordering::Relaxed); + drop(st); + json!({"error": format!("Index build failed: {}", e)}) + } + }; + result } } else { let mut st = state.lock().unwrap(); diff --git a/src/search/mod.rs b/src/search/mod.rs index 8a47ba2..bcf8425 100644 --- a/src/search/mod.rs +++ b/src/search/mod.rs @@ -309,39 +309,44 @@ fn _search_phrases(db_path: &str, query: &str, top_n: usize) -> Vec<(String, f64 Ok(q) => q, Err(_) => continue, }; - let rows_result = exact_q - .query_map([st], |r| { - let fid = r.get::<_, i64>(0)?; - let flags = r.get::<_, Vec>(1)?; - let line_blob = r.get::<_, Vec>(2)?; - let doc_len = r.get::<_, f64>(3)?; - let uniq_def = r.get::<_, i64>(4)?; - let total_def = r.get::<_, i64>(5)?; - let comment_ratio = r.get::<_, f64>(6)?; - let overflow_count = r.get::<_, u32>(7)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - let is_def = schema::unpack_is_def(f); - let zone_int = schema::unpack_zone_int(f); - let base_count = schema::unpack_count(f); - let tf = if base_count >= 31 { - overflow_count as f64 - } else { - base_count as f64 - }; - let first_line = schema::unpack_line_nos(&line_blob) as i32; - Ok(( - fid, - tf, - is_def, - zone_int, - doc_len, - uniq_def, - total_def, - comment_ratio, - first_line, - )) - }) - .unwrap(); + let rows_result = exact_q.query_map([st], |r| { + let fid = r.get::<_, i64>(0)?; + let flags = r.get::<_, Vec>(1)?; + let line_blob = r.get::<_, Vec>(2)?; + let doc_len = r.get::<_, f64>(3)?; + let uniq_def = r.get::<_, i64>(4)?; + let total_def = r.get::<_, i64>(5)?; + let comment_ratio = r.get::<_, f64>(6)?; + let overflow_count = r.get::<_, u32>(7)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + let is_def = schema::unpack_is_def(f); + let zone_int = schema::unpack_zone_int(f); + let base_count = schema::unpack_count(f); + let tf = if base_count >= 31 { + overflow_count as f64 + } else { + base_count as f64 + }; + let first_line = schema::unpack_line_nos(&line_blob) as i32; + Ok(( + fid, + tf, + is_def, + zone_int, + doc_len, + uniq_def, + total_def, + comment_ratio, + first_line, + )) + }); + let rows_result = match rows_result { + Ok(rows) => rows, + Err(e) => { + eprintln!("eh:warn: search: exact query_map failed: {}", e); + continue; + } + }; for row in rows_result.filter_map(|r| r.ok()) { let ( fid, @@ -406,26 +411,31 @@ fn _search_phrases(db_path: &str, query: &str, top_n: usize) -> Vec<(String, f64 Err(_) => continue, }; let mut max_per_file: HashMap = HashMap::new(); - let prefix_rows = prefix_q - .query_map(params![&pattern, st], |r| { - let fid = r.get::<_, i64>(0)?; - let flags = r.get::<_, Vec>(1)?; - let doc_len = r.get::<_, f64>(2)?; - let uniq_def = r.get::<_, i64>(3)?; - let total_def = r.get::<_, i64>(4)?; - let comment_ratio = r.get::<_, f64>(5)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - Ok(( - fid, - schema::unpack_is_def(f), - schema::unpack_zone_int(f), - doc_len, - uniq_def, - total_def, - comment_ratio, - )) - }) - .unwrap(); + let prefix_result = prefix_q.query_map(params![&pattern, st], |r| { + let fid = r.get::<_, i64>(0)?; + let flags = r.get::<_, Vec>(1)?; + let doc_len = r.get::<_, f64>(2)?; + let uniq_def = r.get::<_, i64>(3)?; + let total_def = r.get::<_, i64>(4)?; + let comment_ratio = r.get::<_, f64>(5)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + Ok(( + fid, + schema::unpack_is_def(f), + schema::unpack_zone_int(f), + doc_len, + uniq_def, + total_def, + comment_ratio, + )) + }); + let prefix_rows = match prefix_result { + Ok(rows) => rows, + Err(e) => { + eprintln!("eh:warn: search: prefix query_map failed: {}", e); + continue; + } + }; for row in prefix_rows.filter_map(|r| r.ok()) { let (fid, is_def, zone_int, doc_len, uniq_def, total_def, comment_ratio) = row; let tf = 1.0; @@ -479,22 +489,27 @@ fn _search_phrases(db_path: &str, query: &str, top_n: usize) -> Vec<(String, f64 Err(_) => continue, }; let mut max_per_file: HashMap = HashMap::new(); - let sub_rows = sub_q - .query_map(params![&pattern, &excl_prefix, st], |r| { - let fid = r.get::<_, i64>(0)?; - let flags = r.get::<_, Vec>(1)?; - let doc_len = r.get::<_, f64>(2)?; - let comment_ratio = r.get::<_, f64>(3)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - Ok(( - fid, - schema::unpack_is_def(f), - schema::unpack_zone_int(f), - doc_len, - comment_ratio, - )) - }) - .unwrap(); + let sub_result = sub_q.query_map(params![&pattern, &excl_prefix, st], |r| { + let fid = r.get::<_, i64>(0)?; + let flags = r.get::<_, Vec>(1)?; + let doc_len = r.get::<_, f64>(2)?; + let comment_ratio = r.get::<_, f64>(3)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + Ok(( + fid, + schema::unpack_is_def(f), + schema::unpack_zone_int(f), + doc_len, + comment_ratio, + )) + }); + let sub_rows = match sub_result { + Ok(rows) => rows, + Err(e) => { + eprintln!("eh:warn: search: substring query_map failed: {}", e); + continue; + } + }; for row in sub_rows.filter_map(|r| r.ok()) { let (fid, is_def, zone_int, doc_len, comment_ratio) = row; let tf = 1.0; diff --git a/src/structural_risk.rs b/src/structural_risk.rs index 508a1c5..dda3ad0 100644 --- a/src/structural_risk.rs +++ b/src/structural_risk.rs @@ -82,7 +82,6 @@ pub fn who_calls(db_path: &str, name: &str) -> Vec<(String, f64)> { } } - db.close().ok(); results } @@ -144,11 +143,12 @@ pub fn latent_deps(db_path: &str, file: &str) -> Vec<(String, f64)> { ) .ok(); - if fid.is_none() { - db.close().ok(); - return vec![]; - } - let fid = fid.unwrap(); + let fid = match fid { + Some(f) => f, + None => { + return vec![]; + } + }; let fp = file.to_string(); let module = std::path::Path::new(&fp) .parent() @@ -156,49 +156,64 @@ pub fn latent_deps(db_path: &str, file: &str) -> Vec<(String, f64)> { .unwrap_or_default(); // Find rare phrases (df <= 3) that this file defines — unpack is_def in Rust - let mut rare_q = db - .prepare( - "SELECT p.phrase, po.flags FROM phrase_occ po + let mut rare_q = match db.prepare( + "SELECT p.phrase, po.flags FROM phrase_occ po JOIN phrases p ON p.id = po.phrase_id WHERE po.file_id = ?1 AND (SELECT COUNT(*) FROM phrase_occ po2 WHERE po2.phrase_id = po.phrase_id) <= 3 LIMIT 50", - ) - .unwrap(); - let rare_phrases: Vec = rare_q - .query_map([fid], |r| { - let phrase = r.get::<_, String>(0)?; - let flags = r.get::<_, Vec>(1)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - Ok((phrase, schema::unpack_is_def(f))) - }) - .unwrap() - .filter_map(|r| r.ok()) - .filter(|(_, is_def)| *is_def == 1) - .map(|(phrase, _)| phrase) - .collect(); + ) { + Ok(q) => q, + Err(e) => { + eprintln!("eh:warn: latent_deps: prepare rare_q failed: {}", e); + return vec![]; + } + }; + let rare_phrases: Vec = match rare_q.query_map([fid], |r| { + let phrase = r.get::<_, String>(0)?; + let flags = r.get::<_, Vec>(1)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + Ok((phrase, schema::unpack_is_def(f))) + }) { + Ok(rows) => rows + .filter_map(|r| r.ok()) + .filter(|(_, is_def)| *is_def == 1) + .map(|(phrase, _)| phrase) + .collect(), + Err(e) => { + eprintln!("eh:warn: latent_deps: rare_q query failed: {}", e); + return vec![]; + } + }; drop(rare_q); if rare_phrases.is_empty() { - db.close().ok(); return vec![]; } let mut score_map: HashMap = HashMap::new(); for phrase in &rare_phrases { - let mut stmt = db - .prepare( - "SELECT fm.file_path + let mut stmt = match db.prepare( + "SELECT fm.file_path FROM phrase_occ po JOIN phrases p ON p.id = po.phrase_id JOIN file_map fm ON fm.id = po.file_id WHERE p.phrase = ?1 AND po.file_id != ?2 LIMIT 10", - ) - .unwrap(); - let rows = stmt - .query_map(rusqlite::params![phrase, fid], |r| r.get::<_, String>(0)) - .unwrap(); + ) { + Ok(s) => s, + Err(e) => { + eprintln!("eh:warn: latent_deps: inner prepare failed: {}", e); + continue; + } + }; + let rows = match stmt.query_map(rusqlite::params![phrase, fid], |r| r.get::<_, String>(0)) { + Ok(rows) => rows, + Err(e) => { + eprintln!("eh:warn: latent_deps: inner query failed: {}", e); + continue; + } + }; for row in rows.flatten() { let other_module = std::path::Path::new(&row) .parent() @@ -225,30 +240,38 @@ pub fn blast_radius(db_path: &str, file: &str) -> Vec<(String, f64)> { }; // Get distinctive phrases (is_def > 0) — unpack in Rust - let mut phrase_q = db - .prepare( - "SELECT p.phrase, po.flags, - COALESCE(oc.count, 1) as effective_count + let mut phrase_q = match db.prepare( + "SELECT p.phrase, po.flags, + COALESCE(oc.count, 1) as effective_count FROM phrase_occ po JOIN phrases p ON p.id = po.phrase_id LEFT JOIN count_overflow oc ON oc.phrase_id = po.phrase_id AND oc.file_id = po.file_id WHERE po.file_id=?1 LIMIT 50", - ) - .unwrap(); - let phrases: Vec<(String, f64)> = phrase_q - .query_map([fid], |r| { - let phrase = r.get::<_, String>(0)?; - let flags = r.get::<_, Vec>(1)?; - let count = r.get::<_, f64>(2)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - let is_def = schema::unpack_is_def(f); - Ok((phrase, count, is_def)) - }) - .unwrap() - .filter_map(|r| r.ok()) - .filter(|(_, _, is_def)| *is_def > 0) - .map(|(phrase, count, _)| (phrase, count)) - .collect(); + ) { + Ok(q) => q, + Err(e) => { + eprintln!("eh:warn: blast_radius: prepare failed: {}", e); + return vec![]; + } + }; + let phrases: Vec<(String, f64)> = match phrase_q.query_map([fid], |r| { + let phrase = r.get::<_, String>(0)?; + let flags = r.get::<_, Vec>(1)?; + let count = r.get::<_, f64>(2)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + let is_def = schema::unpack_is_def(f); + Ok((phrase, count, is_def)) + }) { + Ok(rows) => rows + .filter_map(|r| r.ok()) + .filter(|(_, _, is_def)| *is_def > 0) + .map(|(phrase, count, _)| (phrase, count)) + .collect(), + Err(e) => { + eprintln!("eh:warn: blast_radius: query failed: {}", e); + return vec![]; + } + }; drop(phrase_q); if phrases.is_empty() { @@ -257,18 +280,24 @@ pub fn blast_radius(db_path: &str, file: &str) -> Vec<(String, f64)> { let mut other_scores: HashMap = HashMap::new(); for (phrase, count) in &phrases { - let mut q = db - .prepare( - "SELECT po.file_id FROM phrase_occ po + let mut q = match db.prepare( + "SELECT po.file_id FROM phrase_occ po JOIN phrases p ON p.id = po.phrase_id WHERE p.phrase=?1 AND po.file_id!=?2 LIMIT 10", - ) - .unwrap(); - let rows: Vec = q - .query_map(params![phrase, fid], |r| r.get::<_, i64>(0)) - .unwrap() - .filter_map(|r| r.ok()) - .collect(); + ) { + Ok(q) => q, + Err(e) => { + eprintln!("eh:warn: blast_radius: inner prepare failed: {}", e); + continue; + } + }; + let rows: Vec = match q.query_map(params![phrase, fid], |r| r.get::<_, i64>(0)) { + Ok(rows) => rows.filter_map(|r| r.ok()).collect(), + Err(e) => { + eprintln!("eh:warn: blast_radius: inner query failed: {}", e); + continue; + } + }; drop(q); for ofid in rows { *other_scores.entry(ofid).or_insert(0.0) += count; @@ -325,12 +354,13 @@ pub fn build_file_index(db_path: &str) -> Vec { Ok(s) => s, Err(_) => return vec![], }; - let files: Vec = stmt - .query_map([], |r| r.get::<_, String>(0)) - .unwrap() - .filter_map(|r| r.ok()) - .collect(); - drop(stmt); + let files: Vec = match stmt.query_map([], |r| r.get::<_, String>(0)) { + Ok(rows) => rows.filter_map(|r| r.ok()).collect(), + Err(e) => { + eprintln!("eh:warn: build_file_index: query failed: {}", e); + return vec![]; + } + }; // db closes on drop files } @@ -422,7 +452,7 @@ pub fn find_verify_candidates(db_path: &str, file: &str) -> Vec<(String, f64)> { }; // Get overlapping phrases with is_def>0 from test files — unpack in Rust - let mut stmt = db.prepare( + let mut stmt = match db.prepare( "SELECT fm.file_path, po2.flags FROM phrase_occ po1 JOIN phrase_occ po2 ON po2.phrase_id = po1.phrase_id AND po2.file_id != po1.file_id @@ -430,17 +460,27 @@ pub fn find_verify_candidates(db_path: &str, file: &str) -> Vec<(String, f64)> { WHERE po1.file_id = ?1 AND (fm.file_path LIKE '%test%' OR fm.file_path LIKE '%spec%' OR fm.file_path LIKE '%__tests__%') LIMIT 500" - ).unwrap(); + ) { + Ok(s) => s, + Err(e) => { + eprintln!("eh:warn: find_verify_candidates: prepare failed: {}", e); + return vec![]; + } + }; let mut overlap_map: HashMap = HashMap::new(); - let rows = stmt - .query_map([fid], |r| { - let fp = r.get::<_, String>(0)?; - let flags = r.get::<_, Vec>(1)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - Ok((fp, schema::unpack_is_def(f))) - }) - .unwrap(); + let rows = match stmt.query_map([fid], |r| { + let fp = r.get::<_, String>(0)?; + let flags = r.get::<_, Vec>(1)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + Ok((fp, schema::unpack_is_def(f))) + }) { + Ok(rows) => rows, + Err(e) => { + eprintln!("eh:warn: find_verify_candidates: query failed: {}", e); + return vec![]; + } + }; for row in rows.filter_map(|r| r.ok()) { let (fp, is_def) = row; if is_def > 0 { @@ -480,21 +520,26 @@ pub fn hologram_plan(db_path: &str, task: &str) -> serde_json::Value { WHERE p.phrase = ?1 LIMIT 50", ) { - let rows = stmt - .query_map([st], |r| { - let fp = r.get::<_, String>(0)?; - let flags = r.get::<_, Vec>(1)?; - let doc_len = r.get::<_, f64>(2)?; - let f = if !flags.is_empty() { flags[0] } else { 0 }; - let base_count = schema::unpack_count(f); - let tf = if base_count >= 31 { - 1.0 - } else { - base_count as f64 - }; - Ok((fp, tf, schema::unpack_is_def(f), doc_len)) - }) - .unwrap(); + let rows_result = stmt.query_map([st], |r| { + let fp = r.get::<_, String>(0)?; + let flags = r.get::<_, Vec>(1)?; + let doc_len = r.get::<_, f64>(2)?; + let f = if !flags.is_empty() { flags[0] } else { 0 }; + let base_count = schema::unpack_count(f); + let tf = if base_count >= 31 { + 1.0 + } else { + base_count as f64 + }; + Ok((fp, tf, schema::unpack_is_def(f), doc_len)) + }); + let rows = match rows_result { + Ok(rows) => rows, + Err(e) => { + eprintln!("eh:warn: hologram_plan: query_map failed: {}", e); + continue; + } + }; for row in rows.flatten() { let (fp, tf, is_def, doc_len) = row; let idf = crate::search::bm25::bm25_idf(n_docs, 5.0);