Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 80 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,63 @@ use pyo3::prelude::*;
use rayon::prelude::*;
use rayon::ThreadPoolBuilder;
use shakmaty::Color;
use shakmaty::{uci::UciMove, Chess, Outcome, Position};
use shakmaty::{uci::UciMove, Chess, Outcome, Position, Square};
use std::io::Cursor;

mod comment_parsing;

// Definition of PyUciMove
#[pyclass(get_all, set_all, module = "rust_pgn_reader_python_binding")]
#[derive(Clone, Debug)]
pub struct PyUciMove {
pub from_square: u8,
pub to_square: u8,
pub promotion: Option<char>,
}

#[pymethods]
impl PyUciMove {
#[new]
fn new(from_square: u8, to_square: u8, promotion: Option<char>) -> Self {
PyUciMove {
from_square,
to_square,
promotion,
}
}

#[getter]
fn get_from_square_name(&self) -> String {
Square::new(self.from_square as u32).to_string()
}

#[getter]
fn get_to_square_name(&self) -> String {
Square::new(self.to_square as u32).to_string()
}

// __str__ method for Python representation
fn __str__(&self) -> String {
let promo_char = self.promotion.map_or("".to_string(), |p| p.to_string());
format!(
"{}{}{}",
Square::new(self.from_square as u32),
Square::new(self.to_square as u32),
promo_char
)
}

// __repr__ for a more developer-friendly representation
fn __repr__(&self) -> String {
format!(
"PyUciMove(from_square={}, to_square={}, promotion={:?})",
Square::new(self.from_square as u32),
Square::new(self.to_square as u32),
self.promotion
)
}
}

#[pyclass]
/// Holds the status of a chess position.
#[derive(Clone)]
Expand All @@ -32,11 +84,12 @@ pub struct PositionStatus {
#[pyo3(get)]
turn: bool,
}

#[pyclass]
/// A Visitor to extract SAN moves and comments from PGN movetext
pub struct MoveExtractor {
#[pyo3(get)]
moves: Vec<String>,
moves: Vec<PyUciMove>,

#[pyo3(get)]
valid_moves: bool,
Expand Down Expand Up @@ -146,9 +199,30 @@ impl Visitor for MoveExtractor {
match san_plus.san.to_move(&self.pos) {
Ok(m) => {
self.pos.play_unchecked(&m);
let uci = UciMove::from_standard(&m);
self.moves.push(uci.to_string());
self.push_castling_bitboards();
let uci_move_obj = UciMove::from_standard(&m);

match uci_move_obj {
UciMove::Normal {
from,
to,
promotion: promo_opt,
} => {
let py_uci_move = PyUciMove {
from_square: from as u8,
to_square: to as u8,
promotion: promo_opt.map(|p| p.char()),
};
self.moves.push(py_uci_move);
self.push_castling_bitboards();
}
_ => {
// This case handles UciMove::Put and UciMove::Null,
// which are not expected from standard PGN moves
// that PyUciMove is designed to represent.
eprintln!("Unexpected UCI move type from standard PGN move: {:?}. Game moves might be invalid.", uci_move_obj);
self.valid_moves = false;
}
}
}
Err(err) => {
eprintln!("error in game: {} {}", err, san_plus);
Expand Down Expand Up @@ -285,5 +359,6 @@ fn rust_pgn_reader_python_binding(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(parse_games, m)?)?;
m.add_class::<MoveExtractor>()?;
m.add_class::<PositionStatus>()?;
m.add_class::<PyUciMove>()?;
Ok(())
}