Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pgn-reader = "0.26.0"
nom = "8.0"
rayon = "1.10"
num_cpus = "1.17"
arrow-array = "55"
pyo3-arrow = "0.10"

[dev-dependencies]
criterion = "0.6"
Expand Down
20 changes: 20 additions & 0 deletions src/bench_parquet_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Benchmark for ``parse_games_arrow_chunked_array``.

Reads the ``movetext`` column of a Parquet file and prints how long the
Rust extension takes to parse it. Only the parse call is timed; reading
the Parquet file is excluded from the measurement.
"""
import time
from datetime import timedelta

import rust_pgn_reader_python_binding
import pyarrow.parquet as pq

file_path = "2013-07-train-00000-of-00001.parquet"

# Load only the column that is handed to the parser.
pf = pq.ParquetFile(file_path)
movetext_arrow_array = pf.read(columns=["movetext"]).column("movetext")

# perf_counter() is monotonic and high-resolution, so the measurement cannot be
# skewed by wall-clock adjustments the way datetime.now() differences can.
start = time.perf_counter()

extractors = rust_pgn_reader_python_binding.parse_games_arrow_chunked_array(
    movetext_arrow_array
)

elapsed_seconds = time.perf_counter() - start

# Printed as a timedelta so the output keeps its H:MM:SS.ffffff shape.
print(timedelta(seconds=elapsed_seconds))
144 changes: 139 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,69 @@
use crate::comment_parsing::parse_comments;
use crate::comment_parsing::CommentContent;
use arrow_array::{Array, LargeStringArray, StringArray};
use pgn_reader::{BufferedReader, RawComment, RawHeader, SanPlus, Skip, Visitor};
use pyo3::prelude::*;
use pyo3_arrow::PyChunkedArray;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;
use shakmaty::Color;
use shakmaty::{uci::UciMove, Chess, Outcome, Position};
use shakmaty::{uci::UciMove, Chess, Outcome, Position, Square};
use std::io::Cursor;

mod comment_parsing;

// Definition of PyUciMove
#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")]
#[derive(Clone, Debug)]
pub struct PyUciMove {
pub from_square: u8,
pub to_square: u8,
pub promotion: Option<char>,
}

#[pymethods]
impl PyUciMove {
#[new]
fn new(from_square: u8, to_square: u8, promotion: Option<char>) -> Self {
PyUciMove {
from_square,
to_square,
promotion,
}
}

#[getter]
fn get_from_square_name(&self) -> String {
Square::new(self.from_square as u32).to_string()
}

#[getter]
fn get_to_square_name(&self) -> String {
Square::new(self.to_square as u32).to_string()
}

// __str__ method for Python representation
fn __str__(&self) -> String {
let promo_char = self.promotion.map_or("".to_string(), |p| p.to_string());
format!(
"{}{}{}",
Square::new(self.from_square as u32),
Square::new(self.to_square as u32),
promo_char
)
}

// __repr__ for a more developer-friendly representation
fn __repr__(&self) -> String {
format!(
"PyUciMove(from_square={}, to_square={}, promotion={:?})",
Square::new(self.from_square as u32),
Square::new(self.to_square as u32),
self.promotion
)
}
}

#[pyclass]
/// Holds the status of a chess position.
#[derive(Clone)]
Expand All @@ -32,11 +86,12 @@ pub struct PositionStatus {
#[pyo3(get)]
turn: bool,
}

#[pyclass]
/// A Visitor to extract SAN moves and comments from PGN movetext
pub struct MoveExtractor {
#[pyo3(get)]
moves: Vec<String>,
moves: Vec<PyUciMove>,

#[pyo3(get)]
valid_moves: bool,
Expand Down Expand Up @@ -146,9 +201,30 @@ impl Visitor for MoveExtractor {
match san_plus.san.to_move(&self.pos) {
Ok(m) => {
self.pos.play_unchecked(&m);
let uci = UciMove::from_standard(&m);
self.moves.push(uci.to_string());
self.push_castling_bitboards();
let uci_move_obj = UciMove::from_standard(&m);

match uci_move_obj {
UciMove::Normal {
from,
to,
promotion: promo_opt,
} => {
let py_uci_move = PyUciMove {
from_square: from as u8,
to_square: to as u8,
promotion: promo_opt.map(|p| p.char()),
};
self.moves.push(py_uci_move);
self.push_castling_bitboards();
}
_ => {
// This case handles UciMove::Put and UciMove::Null,
// which are not expected from standard PGN moves
// that PyUciMove is designed to represent.
eprintln!("Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", uci_move_obj);
self.valid_moves = false;
}
}
}
Err(err) => {
eprintln!("error in game: {} {}", err, san_plus);
Expand Down Expand Up @@ -263,7 +339,53 @@ pub fn parse_multiple_games_native(
})
}

/// Parses the PGN strings of an Arrow chunked array in parallel.
///
/// A dedicated Rayon pool is built with `num_threads` workers, or with
/// `num_cpus::get()` workers when `num_threads` is `None`. Each chunk must be a
/// `StringArray` or a `LargeStringArray`.
///
/// Null entries are skipped, so the returned vector can be shorter than the
/// input and its positions do not necessarily line up with the input rows.
///
/// # Errors
///
/// Returns an error message when the thread pool cannot be built, when a chunk
/// has an unsupported array type, or when parsing a game fails.
fn _parse_games_from_arrow_chunks_native(
    pgn_chunked_array: &PyChunkedArray,
    num_threads: Option<usize>,
) -> Result<Vec<MoveExtractor>, String> {
    let worker_count = num_threads.unwrap_or_else(num_cpus::get);
    let pool = ThreadPoolBuilder::new()
        .num_threads(worker_count)
        .build()
        .map_err(|e| format!("Failed to build Rayon thread pool: {}", e))?;

    // Upper bound on the number of games: every row, null or not.
    let total_rows: usize = pgn_chunked_array
        .chunks()
        .iter()
        .map(|chunk| chunk.len())
        .sum();
    let mut games: Vec<&str> = Vec::with_capacity(total_rows);

    for chunk in pgn_chunked_array.chunks() {
        let any_array = chunk.as_any();
        if let Some(utf8) = any_array.downcast_ref::<StringArray>() {
            games.extend(
                (0..utf8.len())
                    .filter(move |&row| utf8.is_valid(row))
                    .map(move |row| utf8.value(row)),
            );
        } else if let Some(large_utf8) = any_array.downcast_ref::<LargeStringArray>() {
            games.extend(
                (0..large_utf8.len())
                    .filter(move |&row| large_utf8.is_valid(row))
                    .map(move |row| large_utf8.value(row)),
            );
        } else {
            return Err(format!(
                "Unsupported array type in ChunkedArray: {:?}",
                chunk.data_type()
            ));
        }
    }

    // Run the per-game parsing inside the pool built above.
    pool.install(|| {
        games
            .par_iter()
            .map(|&pgn| parse_single_game_native(pgn))
            .collect::<Result<Vec<MoveExtractor>, String>>()
    })
}

// --- Python-facing wrappers (PyResult) ---
// TODO: check whether `py.allow_threads` can be used here to release the GIL while parsing.
// See https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/
#[pyfunction]
/// Parses a single PGN game string.
fn parse_game(pgn: &str) -> PyResult<MoveExtractor> {
Expand All @@ -278,12 +400,24 @@ fn parse_games(pgns: Vec<String>, num_threads: Option<usize>) -> PyResult<Vec<Mo
.map_err(|err| pyo3::exceptions::PyValueError::new_err(err))
}

#[pyfunction]
/// Parses the PGN strings held in an Arrow chunked array.
///
/// `num_threads` is forwarded to the native implementation and defaults to
/// `None`. Any error it reports is raised as a `ValueError`.
#[pyo3(signature = (pgn_chunked_array, num_threads=None))]
fn parse_games_arrow_chunked_array(
    pgn_chunked_array: PyChunkedArray,
    num_threads: Option<usize>,
) -> PyResult<Vec<MoveExtractor>> {
    let parsed = _parse_games_from_arrow_chunks_native(&pgn_chunked_array, num_threads);
    parsed.map_err(pyo3::exceptions::PyValueError::new_err)
}

/// Parser for chess PGN notation
#[pymodule]
fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Module-level functions, registered one after another.
    let single_game_fn = wrap_pyfunction!(parse_game, m)?;
    m.add_function(single_game_fn)?;

    let many_games_fn = wrap_pyfunction!(parse_games, m)?;
    m.add_function(many_games_fn)?;

    let arrow_games_fn = wrap_pyfunction!(parse_games_arrow_chunked_array, m)?;
    m.add_function(arrow_games_fn)?;

    // Classes that the functions above return or that callers may construct.
    m.add_class::<MoveExtractor>()?;
    m.add_class::<PositionStatus>()?;
    m.add_class::<PyUciMove>()?;

    Ok(())
}