diff --git a/benches/parquet_bench.rs b/benches/parquet_bench.rs index 9471e67..fb69596 100644 --- a/benches/parquet_bench.rs +++ b/benches/parquet_bench.rs @@ -43,7 +43,7 @@ pub fn bench_parquet() { // Measure start time let start = Instant::now(); - let result = parse_multiple_games_native(&movetexts, None); + let result = parse_multiple_games_native(&movetexts, None, false); let duration = start.elapsed(); println!("Time taken: {:?}", duration); diff --git a/rust_pgn_reader_python_binding.pyi b/rust_pgn_reader_python_binding.pyi index 90c571c..c92ce8d 100644 --- a/rust_pgn_reader_python_binding.pyi +++ b/rust_pgn_reader_python_binding.pyi @@ -37,13 +37,15 @@ class MoveExtractor: castling_rights: List[Tuple[bool, bool, bool, bool]] position_status: Optional[PositionStatus] - def __init__(self) -> None: ... + def __init__(self, store_legal_moves: bool = False) -> None: ... def turn(self) -> bool: ... -def parse_game(pgn: str) -> MoveExtractor: ... +def parse_game(pgn: str, store_legal_moves: bool = False) -> MoveExtractor: ... def parse_games( - pgns: List[str], num_threads: Optional[int] = None + pgns: List[str], num_threads: Optional[int] = None, store_legal_moves: bool = False ) -> List[MoveExtractor]: ... def parse_game_moves_arrow_chunked_array( - pgn_chunked_array: pyarrow.ChunkedArray, num_threads: Optional[int] = None + pgn_chunked_array: pyarrow.ChunkedArray, + num_threads: Optional[int] = None, + store_legal_moves: bool = False, ) -> List[MoveExtractor]: ... diff --git a/src/lib.rs b/src/lib.rs index 0bb8f17..e589a80 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -111,6 +111,10 @@ pub struct MoveExtractor { #[pyo3(get)] moves: Vec, + store_legal_moves: bool, + flat_legal_moves: Vec, + legal_moves_offsets: Vec, + #[pyo3(get)] valid_moves: bool, @@ -141,9 +145,13 @@ pub struct MoveExtractor { #[pymethods] impl MoveExtractor { #[new] - fn new() -> MoveExtractor { + #[pyo3(signature = (store_legal_moves = false))] + fn new(store_legal_moves: bool) -> MoveExtractor { MoveExtractor { moves: Vec::with_capacity(100), + store_legal_moves, + flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves + legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets pos: Chess::default(), valid_moves: true, comments: Vec::with_capacity(100), @@ -175,6 +183,30 @@ impl MoveExtractor { self.castling_rights.push(Some(castling_rights)); } + fn push_legal_moves(&mut self) { + // Record the starting offset for the current position's legal moves. + self.legal_moves_offsets.push(self.flat_legal_moves.len()); + + let legal_moves_for_pos = self.pos.legal_moves(); + self.flat_legal_moves.reserve(legal_moves_for_pos.len()); + + for m in legal_moves_for_pos { + let uci_move_obj = UciMove::from_standard(m); + if let UciMove::Normal { + from, + to, + promotion: promo_opt, + } = uci_move_obj + { + self.flat_legal_moves.push(PyUciMove { + from_square: from as u8, + to_square: to as u8, + promotion: promo_opt.map(|p_role| p_role as u8), + }); + } + } + } + fn update_position_status(&mut self) { // TODO this checks legal_moves() a bunch of times self.position_status = Some(PositionStatus { @@ -192,6 +224,27 @@ impl MoveExtractor { }, }); } + + #[getter] + fn legal_moves(&self) -> Vec> { + let mut result = Vec::with_capacity(self.legal_moves_offsets.len()); + if self.legal_moves_offsets.is_empty() { + return result; + } + + for i in 0..self.legal_moves_offsets.len() - 1 { + let start = self.legal_moves_offsets[i]; + let end = self.legal_moves_offsets[i + 1]; + result.push(self.flat_legal_moves[start..end].to_vec()); + } + + // Handle the last chunk + if let Some(&start) = self.legal_moves_offsets.last() { + result.push(self.flat_legal_moves[start..].to_vec()); + } + + result + } } impl Visitor for MoveExtractor { @@ -219,6 +272,8 @@ impl Visitor for MoveExtractor { fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow { self.headers = tags; self.moves.clear(); + self.flat_legal_moves.clear(); + self.legal_moves_offsets.clear(); self.pos = Chess::default(); self.valid_moves = true; self.comments.clear(); @@ -227,6 +282,9 @@ impl Visitor for MoveExtractor { self.castling_rights.clear(); self.push_castling_bitboards(); + if self.store_legal_moves { + self.push_legal_moves(); + } ControlFlow::Continue(()) } @@ -241,6 +299,9 @@ impl Visitor for MoveExtractor { match san_plus.san.to_move(&self.pos) { Ok(m) => { self.pos.play_unchecked(m); + if self.store_legal_moves { + self.push_legal_moves(); + } let uci_move_obj = UciMove::from_standard(m); match uci_move_obj { @@ -372,9 +433,12 @@ impl Visitor for MoveExtractor { } // --- Native Rust versions (no PyResult) --- -pub fn parse_single_game_native(pgn: &str) -> Result { +pub fn parse_single_game_native( + pgn: &str, + store_legal_moves: bool, +) -> Result { let mut reader = Reader::new(Cursor::new(pgn)); - let mut extractor = MoveExtractor::new(); + let mut extractor = MoveExtractor::new(store_legal_moves); match reader.read_game(&mut extractor) { Ok(Some(_)) => Ok(extractor), Ok(None) => Err("No game found in PGN".to_string()), @@ -385,6 +449,7 @@ pub fn parse_single_game_native(pgn: &str) -> Result { pub fn parse_multiple_games_native( pgns: &Vec, num_threads: Option, + store_legal_moves: bool, ) -> Result, String> { let num_threads = num_threads.unwrap_or_else(num_cpus::get); @@ -396,7 +461,7 @@ pub fn parse_multiple_games_native( thread_pool.install(|| { pgns.par_iter() - .map(|pgn| parse_single_game_native(pgn)) + .map(|pgn| parse_single_game_native(pgn, store_legal_moves)) .collect() }) } @@ -404,6 +469,7 @@ pub fn parse_multiple_games_native( fn _parse_game_moves_from_arrow_chunks_native( pgn_chunked_array: &PyChunkedArray, num_threads: Option, + store_legal_moves: bool, ) -> Result, String> { let num_threads = num_threads.unwrap_or_else(num_cpus::get); let thread_pool = ThreadPoolBuilder::new() @@ -440,7 +506,7 @@ fn _parse_game_moves_from_arrow_chunks_native( thread_pool.install(|| { pgn_str_slices .par_iter() - .map(|&pgn_s| parse_single_game_native(pgn_s)) + .map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves)) .collect::, String>>() }) } @@ -449,25 +515,33 @@ fn _parse_game_moves_from_arrow_chunks_native( // TODO check if I can call py.allow_threads and release GIL // see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/ #[pyfunction] +#[pyo3(signature = (pgn, store_legal_moves = false))] /// Parses a single PGN game string. -fn parse_game(pgn: &str) -> PyResult { - parse_single_game_native(pgn).map_err(pyo3::exceptions::PyValueError::new_err) +fn parse_game(pgn: &str, store_legal_moves: bool) -> PyResult { + parse_single_game_native(pgn, store_legal_moves) + .map_err(pyo3::exceptions::PyValueError::new_err) } /// In parallel, parse a set of games #[pyfunction] -#[pyo3(signature = (pgns, num_threads=None))] -fn parse_games(pgns: Vec, num_threads: Option) -> PyResult> { - parse_multiple_games_native(&pgns, num_threads).map_err(pyo3::exceptions::PyValueError::new_err) +#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false))] +fn parse_games( + pgns: Vec, + num_threads: Option, + store_legal_moves: bool, +) -> PyResult> { + parse_multiple_games_native(&pgns, num_threads, store_legal_moves) + .map_err(pyo3::exceptions::PyValueError::new_err) } #[pyfunction] -#[pyo3(signature = (pgn_chunked_array, num_threads=None))] +#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false))] fn parse_game_moves_arrow_chunked_array( pgn_chunked_array: PyChunkedArray, num_threads: Option, + store_legal_moves: bool, ) -> PyResult> { - _parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads) + _parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads, store_legal_moves) .map_err(pyo3::exceptions::PyValueError::new_err) }