diff --git a/Cargo.toml b/Cargo.toml
index 1be00cb..d4bf35a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,8 @@ pgn-reader = "0.26.0"
 nom = "8.0"
 rayon = "1.10"
 num_cpus = "1.17"
+arrow-array = "55"
+pyo3-arrow = "0.10"
 
 [dev-dependencies]
 criterion = "0.6"
diff --git a/src/bench_parquet_arrow.py b/src/bench_parquet_arrow.py
new file mode 100644
index 0000000..221982c
--- /dev/null
+++ b/src/bench_parquet_arrow.py
@@ -0,0 +1,20 @@
+import rust_pgn_reader_python_binding
+import pyarrow.parquet as pq
+
+from datetime import datetime
+
+file_path = "2013-07-train-00000-of-00001.parquet"
+
+
+pf = pq.ParquetFile(file_path)
+
+movetext_arrow_array = pf.read(columns=["movetext"]).column("movetext")
+
+a = datetime.now()
+
+extractors = rust_pgn_reader_python_binding.parse_games_arrow_chunked_array(
+    movetext_arrow_array
+)
+
+b = datetime.now()
+print(b - a)
diff --git a/src/lib.rs b/src/lib.rs
index 30b11f0..f290104 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,15 +1,69 @@
 use crate::comment_parsing::parse_comments;
 use crate::comment_parsing::CommentContent;
+use arrow_array::{Array, LargeStringArray, StringArray};
 use pgn_reader::{BufferedReader, RawComment, RawHeader, SanPlus, Skip, Visitor};
 use pyo3::prelude::*;
+use pyo3_arrow::PyChunkedArray;
 use rayon::prelude::*;
 use rayon::ThreadPoolBuilder;
 use shakmaty::Color;
-use shakmaty::{uci::UciMove, Chess, Outcome, Position};
+use shakmaty::{uci::UciMove, Chess, Outcome, Position, Square};
 use std::io::Cursor;
 
 mod comment_parsing;
 
+// Definition of PyUciMove
+#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")]
+#[derive(Clone, Debug)]
+pub struct PyUciMove {
+    pub from_square: u8,
+    pub to_square: u8,
+    pub promotion: Option<char>,
+}
+
+#[pymethods]
+impl PyUciMove {
+    #[new]
+    fn new(from_square: u8, to_square: u8, promotion: Option<char>) -> Self {
+        PyUciMove {
+            from_square,
+            to_square,
+            promotion,
+        }
+    }
+
+    #[getter]
+    fn get_from_square_name(&self) -> String {
+        Square::new(self.from_square as u32).to_string()
+    }
+
+    #[getter]
+    fn get_to_square_name(&self) -> String {
+        Square::new(self.to_square as u32).to_string()
+    }
+
+    // __str__ method for Python representation
+    fn __str__(&self) -> String {
+        let promo_char = self.promotion.map_or("".to_string(), |p| p.to_string());
+        format!(
+            "{}{}{}",
+            Square::new(self.from_square as u32),
+            Square::new(self.to_square as u32),
+            promo_char
+        )
+    }
+
+    // __repr__ for a more developer-friendly representation
+    fn __repr__(&self) -> String {
+        format!(
+            "PyUciMove(from_square={}, to_square={}, promotion={:?})",
+            Square::new(self.from_square as u32),
+            Square::new(self.to_square as u32),
+            self.promotion
+        )
+    }
+}
+
 #[pyclass]
 /// Holds the status of a chess position.
 #[derive(Clone)]
@@ -32,11 +86,12 @@ pub struct PositionStatus {
     #[pyo3(get)]
     turn: bool,
 }
+
 #[pyclass]
 /// A Visitor to extract SAN moves and comments from PGN movetext
 pub struct MoveExtractor {
     #[pyo3(get)]
-    moves: Vec<String>,
+    moves: Vec<PyUciMove>,
 
     #[pyo3(get)]
     valid_moves: bool,
@@ -146,9 +201,30 @@ impl Visitor for MoveExtractor {
             match san_plus.san.to_move(&self.pos) {
                 Ok(m) => {
                     self.pos.play_unchecked(&m);
-                    let uci = UciMove::from_standard(&m);
-                    self.moves.push(uci.to_string());
-                    self.push_castling_bitboards();
+                    let uci_move_obj = UciMove::from_standard(&m);
+
+                    match uci_move_obj {
+                        UciMove::Normal {
+                            from,
+                            to,
+                            promotion: promo_opt,
+                        } => {
+                            let py_uci_move = PyUciMove {
+                                from_square: from as u8,
+                                to_square: to as u8,
+                                promotion: promo_opt.map(|p| p.char()),
+                            };
+                            self.moves.push(py_uci_move);
+                            self.push_castling_bitboards();
+                        }
+                        _ => {
+                            // This case handles UciMove::Put and UciMove::Null,
+                            // which are not expected from standard PGN moves
+                            // that PyUciMove is designed to represent.
+                            eprintln!("Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", uci_move_obj);
+                            self.valid_moves = false;
+                        }
+                    }
                 }
                 Err(err) => {
                     eprintln!("error in game: {} {}", err, san_plus);
@@ -263,7 +339,53 @@ pub fn parse_multiple_games_native(
     })
 }
 
+fn _parse_games_from_arrow_chunks_native(
+    pgn_chunked_array: &PyChunkedArray,
+    num_threads: Option<usize>,
+) -> Result<Vec<MoveExtractor>, String> {
+    let num_threads = num_threads.unwrap_or_else(|| num_cpus::get());
+    let thread_pool = ThreadPoolBuilder::new()
+        .num_threads(num_threads)
+        .build()
+        .map_err(|e| format!("Failed to build Rayon thread pool: {}", e))?;
+
+    let mut num_elements = 0;
+    for chunk in pgn_chunked_array.chunks() {
+        num_elements += chunk.len();
+    }
+    let mut pgn_str_slices: Vec<&str> = Vec::with_capacity(num_elements);
+    for chunk in pgn_chunked_array.chunks() {
+        if let Some(string_array) = chunk.as_any().downcast_ref::<StringArray>() {
+            for i in 0..string_array.len() {
+                if string_array.is_valid(i) {
+                    pgn_str_slices.push(string_array.value(i));
+                }
+            }
+        } else if let Some(large_string_array) = chunk.as_any().downcast_ref::<LargeStringArray>() {
+            for i in 0..large_string_array.len() {
+                if large_string_array.is_valid(i) {
+                    pgn_str_slices.push(large_string_array.value(i));
+                }
+            }
+        } else {
+            return Err(format!(
+                "Unsupported array type in ChunkedArray: {:?}",
+                chunk.data_type()
+            ));
+        }
+    }
+
+    thread_pool.install(|| {
+        pgn_str_slices
+            .par_iter()
+            .map(|&pgn_s| parse_single_game_native(pgn_s))
+            .collect::<Result<Vec<MoveExtractor>, String>>()
+    })
+}
+
 // --- Python-facing wrappers (PyResult) ---
+// TODO check if I can call py.allow_threads and release GIL
+// see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/
 #[pyfunction]
 /// Parses a single PGN game string.
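 fn parse_game(pgn: &str) -> PyResult<MoveExtractor> {
@@ -278,12 +400,24 @@ fn parse_games(pgns: Vec<String>, num_threads: Option<usize>) -> PyResult<Vec<Mo
         .map_err(|err| pyo3::exceptions::PyValueError::new_err(err))
 }
 
+#[pyfunction]
+#[pyo3(signature = (pgn_chunked_array, num_threads=None))]
+fn parse_games_arrow_chunked_array(
+    pgn_chunked_array: PyChunkedArray,
+    num_threads: Option<usize>,
+) -> PyResult<Vec<MoveExtractor>> {
+    _parse_games_from_arrow_chunks_native(&pgn_chunked_array, num_threads)
+        .map_err(|err| pyo3::exceptions::PyValueError::new_err(err))
+}
+
 /// Parser for chess PGN notation
 #[pymodule]
 fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(parse_game, m)?)?;
     m.add_function(wrap_pyfunction!(parse_games, m)?)?;
+    m.add_function(wrap_pyfunction!(parse_games_arrow_chunked_array, m)?)?;
     m.add_class::<MoveExtractor>()?;
     m.add_class::<PositionStatus>()?;
+    m.add_class::<PyUciMove>()?;
     Ok(())
 }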