From ed1f816e85cb56557bb3c4b637c95d9096acf6b3 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sun, 13 Jul 2025 20:09:16 -0400 Subject: [PATCH 1/4] Add legal moves per position --- src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 0bb8f17..29606f5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,6 +111,9 @@ pub struct MoveExtractor { #[pyo3(get)] moves: Vec, + #[pyo3(get)] + legal_moves: Vec>, + #[pyo3(get)] valid_moves: bool, @@ -144,6 +147,7 @@ impl MoveExtractor { fn new() -> MoveExtractor { MoveExtractor { moves: Vec::with_capacity(100), + legal_moves: Vec::with_capacity(100), pos: Chess::default(), valid_moves: true, comments: Vec::with_capacity(100), @@ -219,6 +223,7 @@ impl Visitor for MoveExtractor { fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { self.headers = tags; self.moves.clear(); + self.legal_moves.clear(); self.pos = Chess::default(); self.valid_moves = true; self.comments.clear(); From fc3bf5ddb72ca3a6dfe92d0c2865fe9370e16a0e Mon Sep 17 00:00:00 2001 From: vladkvit Date: Tue, 15 Jul 2025 18:45:42 -0400 Subject: [PATCH 2/4] Fix bug with end state not being captured by legal_moves --- src/lib.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 29606f5..cfe194f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -179,6 +179,36 @@ impl MoveExtractor { self.castling_rights.push(Some(castling_rights)); } + fn push_legal_moves(&mut self) { + let legal_moves_for_pos = self.pos.legal_moves(); + let mut uci_moves_for_pos: Vec = Vec::with_capacity(legal_moves_for_pos.len()); + + for m in legal_moves_for_pos { + let uci_move_obj = UciMove::from_standard(m); + match uci_move_obj { + UciMove::Normal { + from, + to, + promotion: promo_opt, + } => { + let py_uci_move = PyUciMove { + from_square: from as u8, + to_square: to as u8, + promotion: promo_opt.map(|p_role| p_role as u8), + }; + uci_moves_for_pos.push(py_uci_move); + } + _ => { + eprintln!( + "Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", + uci_move_obj + ); + } + } + } + self.legal_moves.push(uci_moves_for_pos); + } + fn update_position_status(&mut self) { // TODO this checks legal_moves() a bunch of times self.position_status = Some(PositionStatus { @@ -232,6 +262,7 @@ impl Visitor for MoveExtractor { self.castling_rights.clear(); self.push_castling_bitboards(); + self.push_legal_moves(); ControlFlow::Continue(()) } @@ -246,6 +277,7 @@ impl Visitor for MoveExtractor { match san_plus.san.to_move(&self.pos) { Ok(m) => { self.pos.play_unchecked(m); + self.push_legal_moves(); let uci_move_obj = UciMove::from_standard(m); match uci_move_obj { From d72a0819c823adfb852da560546621bf91fd7546 Mon Sep 17 00:00:00 2001 From: vladkvit Date: Sat, 9 Aug 2025 16:06:26 -0400 Subject: [PATCH 3/4] Flatten the legal moves vec --- src/lib.rs | 67 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index cfe194f..558ffa3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,8 +111,8 @@ pub struct MoveExtractor { #[pyo3(get)] moves: Vec, - #[pyo3(get)] - legal_moves: Vec>, + flat_legal_moves: Vec, + legal_moves_offsets: Vec, #[pyo3(get)] valid_moves: bool, @@ -147,7 +147,8 @@ impl MoveExtractor { fn new() -> MoveExtractor { MoveExtractor { moves: Vec::with_capacity(100), - legal_moves: Vec::with_capacity(100), + flat_legal_moves: Vec::with_capacity(100 * 30), // Pre-allocate for moves + legal_moves_offsets: Vec::with_capacity(100), // Pre-allocate for offsets pos: Chess::default(), valid_moves: true, comments: Vec::with_capacity(100), @@ -180,33 +181,27 @@ impl MoveExtractor { } fn push_legal_moves(&mut self) { + // Record the starting offset for the current position's legal moves. + self.legal_moves_offsets.push(self.flat_legal_moves.len()); + let legal_moves_for_pos = self.pos.legal_moves(); - let mut uci_moves_for_pos: Vec = Vec::with_capacity(legal_moves_for_pos.len()); + self.flat_legal_moves.reserve(legal_moves_for_pos.len()); for m in legal_moves_for_pos { let uci_move_obj = UciMove::from_standard(m); - match uci_move_obj { - UciMove::Normal { - from, - to, - promotion: promo_opt, - } => { - let py_uci_move = PyUciMove { - from_square: from as u8, - to_square: to as u8, - promotion: promo_opt.map(|p_role| p_role as u8), - }; - uci_moves_for_pos.push(py_uci_move); - } - _ => { - eprintln!( - "Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", - uci_move_obj - ); - } + if let UciMove::Normal { + from, + to, + promotion: promo_opt, + } = uci_move_obj + { + self.flat_legal_moves.push(PyUciMove { + from_square: from as u8, + to_square: to as u8, + promotion: promo_opt.map(|p_role| p_role as u8), + }); } } - self.legal_moves.push(uci_moves_for_pos); } fn update_position_status(&mut self) { @@ -226,6 +221,27 @@ impl MoveExtractor { }, }); } + + #[getter] + fn legal_moves(&self) -> Vec> { + let mut result = Vec::with_capacity(self.legal_moves_offsets.len()); + if self.legal_moves_offsets.is_empty() { + return result; + } + + for i in 0..self.legal_moves_offsets.len() - 1 { + let start = self.legal_moves_offsets[i]; + let end = self.legal_moves_offsets[i + 1]; + result.push(self.flat_legal_moves[start..end].to_vec()); + } + + // Handle the last chunk + if let Some(&start) = self.legal_moves_offsets.last() { + result.push(self.flat_legal_moves[start..].to_vec()); + } + + result + } } impl Visitor for MoveExtractor { @@ -253,7 +269,8 @@ impl Visitor for MoveExtractor { fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { self.headers = tags; self.moves.clear(); - self.legal_moves.clear(); + self.flat_legal_moves.clear(); + self.legal_moves_offsets.clear(); self.pos = Chess::default(); self.valid_moves = true; self.comments.clear(); From a5a406db1291d457af402fe9b1a416ba69bde5ce Mon Sep 17 00:00:00 2001 From: vladkvit Date: Wed, 13 Aug 2025 14:23:11 -0400 Subject: [PATCH 4/4] Gate legal move storage behind a flag --- benches/parquet_bench.rs | 2 +- rust_pgn_reader_python_binding.pyi | 10 +++--- src/lib.rs | 52 +++++++++++++++++++++--------- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 9471e67..fb69596 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -43,7 +43,7 @@ pub fn bench_parquet() { // Measure start time let start = Instant::now(); - let result = parse_multiple_games_native(&movetexts, None); + let result = parse_multiple_games_native(&movetexts, None, false); let duration = start.elapsed(); println!("Time taken: {:?}", duration); diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index 90c571c..c92ce8d 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -37,13 +37,15 @@ class MoveExtractor: castling_rights: List[Tuple[bool, bool, bool, bool]] position_status: Optional[PositionStatus] - def __init__(self) -> None: ... + def __init__(self, store_legal_moves: bool = False) -> None: ... def turn(self) -> bool: ... -def parse_game(pgn: str) -> MoveExtractor: ... +def parse_game(pgn: str, store_legal_moves: bool = False) -> MoveExtractor: ... def parse_games( - pgns: List[str], num_threads: Optional[int] = None + pgns: List[str], num_threads: Optional[int] = None, store_legal_moves: bool = False ) -> List[MoveExtractor]: ... def parse_game_moves_arrow_chunked_array( - pgn_chunked_array: pyarrow.ChunkedArray, num_threads: Optional[int] = None + pgn_chunked_array: pyarrow.ChunkedArray, + num_threads: Optional[int] = None, + store_legal_moves: bool = False, ) -> List[MoveExtractor]: ... diff --git a/src/lib.rs b/src/lib.rs index 558ffa3..e589a80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,6 +111,7 @@ pub struct MoveExtractor { #[pyo3(get)] moves: Vec, + store_legal_moves: bool, flat_legal_moves: Vec, legal_moves_offsets: Vec, @@ -144,11 +145,13 @@ pub struct MoveExtractor { #[pymethods] impl MoveExtractor { #[new] - fn new() -> MoveExtractor { + #[pyo3(signature = (store_legal_moves = false))] + fn new(store_legal_moves: bool) -> MoveExtractor { MoveExtractor { moves: Vec::with_capacity(100), - flat_legal_moves: Vec::with_capacity(100 * 30), // Pre-allocate for moves - legal_moves_offsets: Vec::with_capacity(100), // Pre-allocate for offsets + store_legal_moves, + flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves + legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets pos: Chess::default(), valid_moves: true, comments: Vec::with_capacity(100), @@ -279,7 +282,9 @@ impl Visitor for MoveExtractor { self.castling_rights.clear(); self.push_castling_bitboards(); - self.push_legal_moves(); + if self.store_legal_moves { + self.push_legal_moves(); + } ControlFlow::Continue(()) } @@ -294,7 +299,9 @@ impl Visitor for MoveExtractor { match san_plus.san.to_move(&self.pos) { Ok(m) => { self.pos.play_unchecked(m); - self.push_legal_moves(); + if self.store_legal_moves { + self.push_legal_moves(); + } let uci_move_obj = UciMove::from_standard(m); match uci_move_obj { @@ -426,9 +433,12 @@ impl Visitor for MoveExtractor { } // --- Native Rust versions (no PyResult) --- -pub fn parse_single_game_native(pgn: &str) -> Result { +pub fn parse_single_game_native( + pgn: &str, + store_legal_moves: bool, +) -> Result { let mut reader = Reader::new(Cursor::new(pgn)); - let mut extractor = MoveExtractor::new(); + let mut extractor = MoveExtractor::new(store_legal_moves); match reader.read_game(&mut extractor) { Ok(Some(_)) => Ok(extractor), Ok(None) => Err("No game found in PGN".to_string()), @@ -439,6 +449,7 @@ pub fn parse_single_game_native(pgn: &str) -> Result { pub fn parse_multiple_games_native( pgns: &Vec, num_threads: Option, + store_legal_moves: bool, ) -> Result, String> { let num_threads = num_threads.unwrap_or_else(num_cpus::get); @@ -450,7 +461,7 @@ pub fn parse_multiple_games_native( thread_pool.install(|| { pgns.par_iter() - .map(|pgn| parse_single_game_native(pgn)) + .map(|pgn| parse_single_game_native(pgn, store_legal_moves)) .collect() }) } @@ -458,6 +469,7 @@ pub fn parse_multiple_games_native( fn _parse_game_moves_from_arrow_chunks_native( pgn_chunked_array: &PyChunkedArray, num_threads: Option, + store_legal_moves: bool, ) -> Result, String> { let num_threads = num_threads.unwrap_or_else(num_cpus::get); let thread_pool = ThreadPoolBuilder::new() @@ -494,7 +506,7 @@ fn _parse_game_moves_from_arrow_chunks_native( thread_pool.install(|| { pgn_str_slices .par_iter() - .map(|&pgn_s| parse_single_game_native(pgn_s)) + .map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves)) .collect::, String>>() }) } @@ -503,25 +515,33 @@ fn _parse_game_moves_from_arrow_chunks_native( // TODO check if I can call py.allow_threads and release GIL // see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/ #[pyfunction] +#[pyo3(signature = (pgn, store_legal_moves = false))] /// Parses a single PGN game string. -fn parse_game(pgn: &str) -> PyResult { - parse_single_game_native(pgn).map_err(pyo3::exceptions::PyValueError::new_err) +fn parse_game(pgn: &str, store_legal_moves: bool) -> PyResult { + parse_single_game_native(pgn, store_legal_moves) + .map_err(pyo3::exceptions::PyValueError::new_err) } /// In parallel, parse a set of games #[pyfunction] -#[pyo3(signature = (pgns, num_threads=None))] -fn parse_games(pgns: Vec, num_threads: Option) -> PyResult> { - parse_multiple_games_native(&pgns, num_threads).map_err(pyo3::exceptions::PyValueError::new_err) +#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false))] +fn parse_games( + pgns: Vec, + num_threads: Option, + store_legal_moves: bool, +) -> PyResult> { + parse_multiple_games_native(&pgns, num_threads, store_legal_moves) + .map_err(pyo3::exceptions::PyValueError::new_err) } #[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None))] +#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false))] fn parse_game_moves_arrow_chunked_array( pgn_chunked_array: PyChunkedArray, num_threads: Option, + store_legal_moves: bool, ) -> PyResult> { - _parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads) + _parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads, store_legal_moves) .map_err(pyo3::exceptions::PyValueError::new_err) }