Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/parquet_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn bench_parquet() {
// Measure start time
let start = Instant::now();

let result = parse_multiple_games_native(&movetexts, None);
let result = parse_multiple_games_native(&movetexts, None, false);

let duration = start.elapsed();
println!("Time taken: {:?}", duration);
Expand Down
10 changes: 6 additions & 4 deletions rust_pgn_reader_python_binding.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ class MoveExtractor:
castling_rights: List[Tuple[bool, bool, bool, bool]]
position_status: Optional[PositionStatus]

def __init__(self) -> None: ...
def __init__(self, store_legal_moves: bool = False) -> None: ...
def turn(self) -> bool: ...

def parse_game(pgn: str) -> MoveExtractor: ...
def parse_game(pgn: str, store_legal_moves: bool = False) -> MoveExtractor: ...
def parse_games(
pgns: List[str], num_threads: Optional[int] = None
pgns: List[str], num_threads: Optional[int] = None, store_legal_moves: bool = False
) -> List[MoveExtractor]: ...
def parse_game_moves_arrow_chunked_array(
pgn_chunked_array: pyarrow.ChunkedArray, num_threads: Optional[int] = None
pgn_chunked_array: pyarrow.ChunkedArray,
num_threads: Optional[int] = None,
store_legal_moves: bool = False,
) -> List[MoveExtractor]: ...
98 changes: 86 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ pub struct MoveExtractor {
#[pyo3(get)]
moves: Vec<PyUciMove>,

store_legal_moves: bool,
flat_legal_moves: Vec<PyUciMove>,
legal_moves_offsets: Vec<usize>,

#[pyo3(get)]
valid_moves: bool,

Expand Down Expand Up @@ -141,9 +145,13 @@ pub struct MoveExtractor {
#[pymethods]
impl MoveExtractor {
#[new]
fn new() -> MoveExtractor {
#[pyo3(signature = (store_legal_moves = false))]
fn new(store_legal_moves: bool) -> MoveExtractor {
MoveExtractor {
moves: Vec::with_capacity(100),
store_legal_moves,
flat_legal_moves: Vec::with_capacity(if store_legal_moves { 100 * 30 } else { 0 }), // Pre-allocate for moves
legal_moves_offsets: Vec::with_capacity(if store_legal_moves { 100 } else { 0 }), // Pre-allocate for offsets
pos: Chess::default(),
valid_moves: true,
comments: Vec::with_capacity(100),
Expand Down Expand Up @@ -175,6 +183,30 @@ impl MoveExtractor {
self.castling_rights.push(Some(castling_rights));
}

/// Appends the legal moves of the current position to the flat move buffer.
///
/// The length of `flat_legal_moves` before appending is pushed onto
/// `legal_moves_offsets`, so each offset marks where one position's moves
/// begin inside the flat buffer.
fn push_legal_moves(&mut self) {
    let offset = self.flat_legal_moves.len();
    self.legal_moves_offsets.push(offset);

    let candidates = self.pos.legal_moves();
    // Grow the buffer once up front instead of repeatedly while appending.
    self.flat_legal_moves.reserve(candidates.len());

    // Only moves that convert to `UciMove::Normal` are stored; any other
    // variant is skipped.
    self.flat_legal_moves
        .extend(candidates.into_iter().filter_map(|mv| {
            match UciMove::from_standard(mv) {
                UciMove::Normal {
                    from,
                    to,
                    promotion,
                } => Some(PyUciMove {
                    from_square: from as u8,
                    to_square: to as u8,
                    promotion: promotion.map(|role| role as u8),
                }),
                _ => None,
            }
        }));
}

fn update_position_status(&mut self) {
// TODO this checks legal_moves() a bunch of times
self.position_status = Some(PositionStatus {
Expand All @@ -192,6 +224,27 @@ impl MoveExtractor {
},
});
}

/// Returns the recorded legal moves grouped per recorded position.
///
/// Group `i` is a copy of the slice of `flat_legal_moves` that starts at
/// `legal_moves_offsets[i]` and ends at the next offset, or at the end of
/// the buffer for the final group. When no offsets have been recorded the
/// result is empty.
#[getter]
fn legal_moves(&self) -> Vec<Vec<PyUciMove>> {
    let offsets = &self.legal_moves_offsets;
    let buffer_end = self.flat_legal_moves.len();

    // Pair every start offset with its successor; the last group is closed
    // by the total buffer length.
    let ends = offsets
        .iter()
        .skip(1)
        .copied()
        .chain(std::iter::once(buffer_end));

    offsets
        .iter()
        .copied()
        .zip(ends)
        .map(|(start, end)| self.flat_legal_moves[start..end].to_vec())
        .collect()
}
}

impl Visitor for MoveExtractor {
Expand Down Expand Up @@ -219,6 +272,8 @@ impl Visitor for MoveExtractor {
fn begin_movetext(&mut self, tags: Self::Tags) -> ControlFlow<Self::Output, Self::Movetext> {
self.headers = tags;
self.moves.clear();
self.flat_legal_moves.clear();
self.legal_moves_offsets.clear();
self.pos = Chess::default();
self.valid_moves = true;
self.comments.clear();
Expand All @@ -227,6 +282,9 @@ impl Visitor for MoveExtractor {
self.castling_rights.clear();

self.push_castling_bitboards();
if self.store_legal_moves {
self.push_legal_moves();
}
ControlFlow::Continue(())
}

Expand All @@ -241,6 +299,9 @@ impl Visitor for MoveExtractor {
match san_plus.san.to_move(&self.pos) {
Ok(m) => {
self.pos.play_unchecked(m);
if self.store_legal_moves {
self.push_legal_moves();
}
let uci_move_obj = UciMove::from_standard(m);

match uci_move_obj {
Expand Down Expand Up @@ -372,9 +433,12 @@ impl Visitor for MoveExtractor {
}

// --- Native Rust versions (no PyResult) ---
pub fn parse_single_game_native(pgn: &str) -> Result<MoveExtractor, String> {
pub fn parse_single_game_native(
pgn: &str,
store_legal_moves: bool,
) -> Result<MoveExtractor, String> {
let mut reader = Reader::new(Cursor::new(pgn));
let mut extractor = MoveExtractor::new();
let mut extractor = MoveExtractor::new(store_legal_moves);
match reader.read_game(&mut extractor) {
Ok(Some(_)) => Ok(extractor),
Ok(None) => Err("No game found in PGN".to_string()),
Expand All @@ -385,6 +449,7 @@ pub fn parse_single_game_native(pgn: &str) -> Result<MoveExtractor, String> {
pub fn parse_multiple_games_native(
pgns: &Vec<String>,
num_threads: Option<usize>,
store_legal_moves: bool,
) -> Result<Vec<MoveExtractor>, String> {
let num_threads = num_threads.unwrap_or_else(num_cpus::get);

Expand All @@ -396,14 +461,15 @@ pub fn parse_multiple_games_native(

thread_pool.install(|| {
pgns.par_iter()
.map(|pgn| parse_single_game_native(pgn))
.map(|pgn| parse_single_game_native(pgn, store_legal_moves))
.collect()
})
}

fn _parse_game_moves_from_arrow_chunks_native(
pgn_chunked_array: &PyChunkedArray,
num_threads: Option<usize>,
store_legal_moves: bool,
) -> Result<Vec<MoveExtractor>, String> {
let num_threads = num_threads.unwrap_or_else(num_cpus::get);
let thread_pool = ThreadPoolBuilder::new()
Expand Down Expand Up @@ -440,7 +506,7 @@ fn _parse_game_moves_from_arrow_chunks_native(
thread_pool.install(|| {
pgn_str_slices
.par_iter()
.map(|&pgn_s| parse_single_game_native(pgn_s))
.map(|&pgn_s| parse_single_game_native(pgn_s, store_legal_moves))
.collect::<Result<Vec<MoveExtractor>, String>>()
})
}
Expand All @@ -449,25 +515,33 @@ fn _parse_game_moves_from_arrow_chunks_native(
// TODO check if I can call py.allow_threads and release GIL
// see https://docs.rs/pyo3-arrow/0.10.1/pyo3_arrow/
#[pyfunction]
#[pyo3(signature = (pgn, store_legal_moves = false))]
/// Parses a single PGN game string.
fn parse_game(pgn: &str) -> PyResult<MoveExtractor> {
parse_single_game_native(pgn).map_err(pyo3::exceptions::PyValueError::new_err)
fn parse_game(pgn: &str, store_legal_moves: bool) -> PyResult<MoveExtractor> {
parse_single_game_native(pgn, store_legal_moves)
.map_err(pyo3::exceptions::PyValueError::new_err)
}

/// In parallel, parse a set of games
#[pyfunction]
#[pyo3(signature = (pgns, num_threads=None))]
fn parse_games(pgns: Vec<String>, num_threads: Option<usize>) -> PyResult<Vec<MoveExtractor>> {
parse_multiple_games_native(&pgns, num_threads).map_err(pyo3::exceptions::PyValueError::new_err)
#[pyo3(signature = (pgns, num_threads=None, store_legal_moves=false))]
fn parse_games(
pgns: Vec<String>,
num_threads: Option<usize>,
store_legal_moves: bool,
) -> PyResult<Vec<MoveExtractor>> {
parse_multiple_games_native(&pgns, num_threads, store_legal_moves)
.map_err(pyo3::exceptions::PyValueError::new_err)
}

#[pyfunction]
#[pyo3(signature = (pgn_chunked_array, num_threads=None))]
#[pyo3(signature = (pgn_chunked_array, num_threads=None, store_legal_moves=false))]
fn parse_game_moves_arrow_chunked_array(
pgn_chunked_array: PyChunkedArray,
num_threads: Option<usize>,
store_legal_moves: bool,
) -> PyResult<Vec<MoveExtractor>> {
_parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads)
_parse_game_moves_from_arrow_chunks_native(&pgn_chunked_array, num_threads, store_legal_moves)
.map_err(pyo3::exceptions::PyValueError::new_err)
}

Expand Down
Loading