From c04c7e5369d8f322fc20f3919858f274d9fdf103 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 24 Apr 2026 15:53:34 +0200 Subject: [PATCH] Update to new ONNX split protobuf + weight API and MLOperandDataType enum --- Cargo.toml | 1 - src/python/context.rs | 47 ++++++++-- src/python/graph.rs | 181 +++++++----------------------------- src/python/graph_builder.rs | 35 +++---- 4 files changed, 94 insertions(+), 170 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4585be5..2c3c54f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ pyo3 = { version = "0.28", features = ["extension-module"] } rustnn = { git = "https://github.com/rustnn/rustnn", branch = "main" } serde_json = "1.0" webnn-graph = { git = "https://github.com/rustnn/webnn-graph", branch = "main" } -safetensors = "0.7" half = "2.4" regex = "1.11" # Optional runtime dependencies diff --git a/src/python/context.rs b/src/python/context.rs index d27d22d..7d040c8 100644 --- a/src/python/context.rs +++ b/src/python/context.rs @@ -18,6 +18,8 @@ use rustnn::Operation; #[cfg(feature = "onnx-runtime")] use rustnn::executors::onnx::{run_onnx_with_inputs, OnnxInput}; +#[cfg(feature = "onnx-runtime")] +use std::borrow::Cow; #[cfg(all(target_os = "macos", feature = "coreml-runtime"))] use rustnn::executors::coreml::run_coreml_zeroed_cached_with_weights; @@ -340,6 +342,19 @@ impl PyMLContext { std::fs::write(output_path, &converted.data).map_err(|e| { pyo3::exceptions::PyIOError::new_err(format!("Failed to write ONNX file: {}", e)) })?; + if let Some(weights) = &converted.weights_data { + let sidecar = std::path::Path::new(output_path) + .parent() + .unwrap_or_else(|| std::path::Path::new(".")) + .join(rustnn::ONNX_EXTERNAL_WEIGHTS_FILENAME); + std::fs::write(&sidecar, weights).map_err(|e| { + pyo3::exceptions::PyIOError::new_err(format!( + "Failed to write ONNX external weights file `{}`: {}", + sidecar.display(), + e + )) + })?; + } Ok(()) } @@ -1107,18 +1122,33 @@ impl PyMLContext { pyo3::exceptions::PyRuntimeError::new_err(format!("ONNX conversion failed: {}", e)) })?; - let session = ort::session::Session::builder() + let rustnn::converters::ConvertedGraph { + data, weights_data, .. + } = converted; + let mut builder = ort::session::Session::builder() .map_err(|e| { pyo3::exceptions::PyRuntimeError::new_err(format!("Session builder failed: {}", e)) })? .with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level1) .map_err(|e| { pyo3::exceptions::PyRuntimeError::new_err(format!("Set opt level failed: {}", e)) - })? - .commit_from_memory(&converted.data) - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("Load model failed: {}", e)) })?; + if let Some(weights) = weights_data { + builder = builder + .with_external_initializer_file_in_memory( + rustnn::ONNX_EXTERNAL_WEIGHTS_FILENAME, + Cow::Owned(weights), + ) + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!( + "Set external initializer failed: {}", + e + )) + })?; + } + let session = builder.commit_from_memory(&data).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("Load model failed: {}", e)) + })?; let session_arc = std::sync::Arc::new(session); *session_guard = Some(std::sync::Arc::clone(&session_arc)); @@ -1265,7 +1295,12 @@ impl PyMLContext { } // Execute with ONNX runtime - let onnx_outputs = run_onnx_with_inputs(&converted.data, onnx_inputs).map_err(|e| { + let onnx_outputs = run_onnx_with_inputs( + &converted.data, + converted.weights_data.as_deref(), + onnx_inputs, + ) + .map_err(|e| { pyo3::exceptions::PyRuntimeError::new_err(format!("ONNX execution failed: {}", e)) })?; diff --git a/src/python/graph.rs b/src/python/graph.rs index 408478e..d823cd6 100644 --- a/src/python/graph.rs +++ b/src/python/graph.rs @@ -9,7 +9,6 @@ use pyo3::exceptions::PyIOError; use pyo3::prelude::*; use rustnn::graph::GraphInfo; use rustnn::webnn_json; -use safetensors::SafeTensors; use std::fs; use std::path::Path; @@ -358,9 +357,10 @@ impl PyMLGraph { /// /// Args: /// path: File path to load the graph from (e.g., "model.webnn") - /// manifest_path: Optional path to manifest.json file for external weights - /// weights_path: Optional path to weights file for external weights, or - /// a single .safetensors file (when manifest_path is not provided) + /// manifest_path: Optional path to manifest.json for manifest + raw weights layout + /// weights_path: Path to a `.safetensors` file, or to a raw `.weights` blob (with manifest + /// passed explicitly or discovered next to the graph). Relative paths are resolved from + /// the graph file’s parent directory (same rules as `webnn-graph`). /// /// Returns: /// MLGraph: The loaded graph @@ -373,6 +373,7 @@ impl PyMLGraph { /// graph = MLGraph.load("my_model.webnn") /// graph = MLGraph.load("model.webnn", manifest_path="manifest.json", weights_path="model.weights") /// graph = MLGraph.load("model.webnn", weights_path="model.safetensors") + /// graph = MLGraph.load("model.webnn", weights_path="custom_name.safetensors") #[staticmethod] #[pyo3(signature = (path, manifest_path=None, weights_path=None))] fn load(path: &str, manifest_path: Option<&str>, weights_path: Option<&str>) -> PyResult { @@ -400,7 +401,7 @@ impl PyMLGraph { }; // Resolve external weight references if present - Self::resolve_external_weights(&mut graph_json, manifest_path, weights_path)?; + Self::resolve_external_weights(&mut graph_json, path_obj, manifest_path, weights_path)?; // Convert GraphJson to GraphInfo let graph_info = webnn_json::from_graph_json(&graph_json) @@ -415,154 +416,42 @@ impl PyMLGraph { Self { graph_info } } - /// Resolve external weight references in a GraphJson - /// - /// This function loads manifest and weights files and resolves all weight references to inline bytes. - /// - /// If weights_path are not provided, returns immediately (no external weights). + /// Delegates to [`webnn_graph::resolve_external_weights`], then surfaces a Python error if any + /// `@weights` refs remain (e.g. partial safetensors coverage). fn resolve_external_weights( graph_json: &mut webnn_graph::ast::GraphJson, - manifest_path: Option<&str>, - weights_path: Option<&str>, - ) -> PyResult<()> { - if let Some(st_path) = Self::resolve_safetensors_path(weights_path) { - return Self::resolve_safetensors_weights(graph_json, st_path); - } - - Self::resolve_manifest_weights(graph_json, manifest_path, weights_path) - } - - fn resolve_safetensors_path(weights_path: Option<&str>) -> Option<&str> { - fn is_safetensors(path: &str) -> bool { - path.ends_with(".safetensors") || path.ends_with(".safetensor") - } - - if let Some(p) = weights_path { - if is_safetensors(p) { - return Some(p); - } - } - None - } - - fn resolve_safetensors_weights( - graph_json: &mut webnn_graph::ast::GraphJson, - safetensors_path: &str, - ) -> PyResult<()> { - use std::collections::HashMap; - use webnn_graph::ast::ConstInit; - - let st_bytes = fs::read(safetensors_path) - .map_err(|e| PyIOError::new_err(format!("Failed to read safetensors: {}", e)))?; - let st = SafeTensors::deserialize(&st_bytes) - .map_err(|e| PyIOError::new_err(format!("Failed to parse safetensors: {}", e)))?; - - // Map sanitized key -> original key to support refs where "." and "::" - // were replaced by "_" and "__". - let mut sanitized_map: HashMap = HashMap::new(); - for key in st.names() { - sanitized_map.insert(key.replace("::", "__").replace('.', "_"), key.to_string()); - } - - for (_name, const_decl) in graph_json.consts.iter_mut() { - if let ConstInit::Weights { r#ref } = &const_decl.init { - let tensor_view = st - .tensor(r#ref) - .or_else(|_| { - sanitized_map - .get(r#ref) - .ok_or_else(|| { - safetensors::SafeTensorError::TensorNotFound(r#ref.clone()) - }) - .and_then(|orig| st.tensor(orig)) - }) - .map_err(|e| { - PyIOError::new_err(format!( - "Weight '{}' not found in safetensors '{}': {}", - r#ref, safetensors_path, e - )) - })?; - - const_decl.init = ConstInit::InlineBytes { - bytes: tensor_view.data().to_vec(), - }; - } - } - - Ok(()) - } - - fn resolve_manifest_weights( - graph_json: &mut webnn_graph::ast::GraphJson, + graph_path: &Path, manifest_path: Option<&str>, weights_path: Option<&str>, ) -> PyResult<()> { use webnn_graph::ast::ConstInit; - use webnn_graph::weights::WeightsManifest; - - // If no manifest path provided, assume no external weights - let manifest_path = match manifest_path { - Some(p) => p, - None => return Ok(()), - }; - // If no weights path provided, assume no external weights - let weights_path = match weights_path { - Some(p) => p, - None => return Ok(()), - }; - - // Load manifest - let manifest_content = fs::read_to_string(manifest_path) - .map_err(|e| PyIOError::new_err(format!("Failed to read manifest: {}", e)))?; - let manifest: WeightsManifest = serde_json::from_str(&manifest_content) - .map_err(|e| PyIOError::new_err(format!("Failed to parse manifest: {}", e)))?; - - // Load weights file - let weights_data = fs::read(weights_path) - .map_err(|e| PyIOError::new_err(format!("Failed to read weights: {}", e)))?; - - // Create a sanitized lookup map: dots and colons in manifest keys -> underscores. - // Some graphs carry sanitized weight refs while others keep original dotted refs. - // We support both formats by checking exact refs first, then sanitized refs. - use std::collections::HashMap; - let sanitized_manifest: HashMap = manifest - .tensors - .iter() - .map(|(key, value)| (key.replace("::", "__").replace('.', "_"), value)) - .collect(); - - // Resolve weight references in constants - for (_name, const_decl) in graph_json.consts.iter_mut() { - if let ConstInit::Weights { r#ref } = &const_decl.init { - // Prefer exact key lookup for modern manifests, then fallback to sanitized. - let tensor_entry = manifest - .tensors - .get(r#ref) - .or_else(|| sanitized_manifest.get(r#ref).copied()); - - if let Some(tensor_entry) = tensor_entry { - let offset = tensor_entry.byte_offset as usize; - let length = tensor_entry.byte_length as usize; - - // Extract bytes from weights file - if offset + length > weights_data.len() { - return Err(PyIOError::new_err(format!( - "Weight '{}' offset/length exceeds weights file size", - r#ref - ))); - } - let bytes = weights_data[offset..offset + length].to_vec(); - - // Replace weight reference with inline bytes - const_decl.init = ConstInit::InlineBytes { bytes }; - } else { - return Err(PyIOError::new_err(format!( - "Weight '{}' not found in manifest", - r#ref - ))); - } - } + webnn_graph::resolve_external_weights( + graph_json, + graph_path, + weights_path, + manifest_path, + ) + .map_err(|e| { + PyIOError::new_err(format!( + "Failed to resolve external weights: {e}. \ + Pass weights_path to a `.safetensors` file, or manifest_path and weights_path for manifest + raw blob, \ + or place sidecar files next to the graph." + )) + })?; + + let pending_count = graph_json + .consts + .values() + .filter(|c| matches!(c.init, ConstInit::Weights { .. })) + .count(); + if pending_count > 0 { + return Err(PyIOError::new_err(format!( + "Graph still has {pending_count} external weight reference(s) after resolution. \ + Pass weights_path to your `.safetensors` file, or manifest_path + weights_path for manifest + raw blob, \ + or place model.safetensors / {{stem}}.safetensors or manifest + weights next to the graph. \ + Rebuild the native extension (`pip install -e .` / maturin) if dependency changes are not picked up.", + ))); } Ok(()) diff --git a/src/python/graph_builder.rs b/src/python/graph_builder.rs index 5cb348f..dcfdd57 100644 --- a/src/python/graph_builder.rs +++ b/src/python/graph_builder.rs @@ -13,6 +13,7 @@ use pyo3::types::PyDict; use rustnn::graph::{ to_dimension_vector, ConstantData, DataType, GraphInfo, Operand, OperandDescriptor, OperandKind, }; +use rustnn::operator_enums::MLOperandDataType; use rustnn::operator_options::{ MLArgMinMaxOptions, MLBatchNormalizationOptions, MLClampOptions, MLConv2dOptions, MLConvTranspose2dOptions, MLDimension, MLEluOptions, MLGatherOptions, MLGemmOptions, @@ -2528,9 +2529,9 @@ impl PyMLGraphBuilder { .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; // Parse output data type, default to int64 - let output_type = match output_data_type { - Some("int32") => DataType::Int32, - Some("int64") | None => DataType::Int64, + let (output_type, ml_output_dtype) = match output_data_type { + Some("int32") => (DataType::Int32, MLOperandDataType::Int32), + Some("int64") | None => (DataType::Int64, MLOperandDataType::Int64), Some(other) => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid output_data_type '{}', must be 'int32' or 'int64'", @@ -2551,7 +2552,7 @@ impl PyMLGraphBuilder { let arg_opts = MLArgMinMaxOptions { label: String::new(), keep_dimensions, - output_data_type: output_data_type.unwrap_or("int64").to_string(), + output_data_type: ml_output_dtype, }; self.push_op(Operation::ArgMax { @@ -2602,9 +2603,9 @@ impl PyMLGraphBuilder { .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?; // Parse output data type, default to int64 - let output_type = match output_data_type { - Some("int32") => DataType::Int32, - Some("int64") | None => DataType::Int64, + let (output_type, ml_output_dtype) = match output_data_type { + Some("int32") => (DataType::Int32, MLOperandDataType::Int32), + Some("int64") | None => (DataType::Int64, MLOperandDataType::Int64), Some(other) => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid output_data_type '{}', must be 'int32' or 'int64'", @@ -2625,7 +2626,7 @@ impl PyMLGraphBuilder { let arg_opts = MLArgMinMaxOptions { label: String::new(), keep_dimensions, - output_data_type: output_data_type.unwrap_or("int64").to_string(), + output_data_type: ml_output_dtype, }; self.push_op(Operation::ArgMin { @@ -2662,14 +2663,14 @@ impl PyMLGraphBuilder { let output_shape = infer_cast_shape(&input.descriptor.static_or_max_shape()); // Parse target data type - let target_type = match data_type { - "float32" => DataType::Float32, - "float16" => DataType::Float16, - "int32" => DataType::Int32, - "uint32" => DataType::Uint32, - "int8" => DataType::Int8, - "uint8" => DataType::Uint8, - "int64" => DataType::Int64, + let (target_type, ml_dtype) = match data_type { + "float32" => (DataType::Float32, MLOperandDataType::Float32), + "float16" => (DataType::Float16, MLOperandDataType::Float16), + "int32" => (DataType::Int32, MLOperandDataType::Int32), + "uint32" => (DataType::Uint32, MLOperandDataType::Uint32), + "int8" => (DataType::Int8, MLOperandDataType::Int8), + "uint8" => (DataType::Uint8, MLOperandDataType::Uint8), + "int64" => (DataType::Int64, MLOperandDataType::Int64), other => { return Err(pyo3::exceptions::PyValueError::new_err(format!( "Invalid data_type '{}', must be one of: float32, float16, int32, uint32, int8, uint8, int64", @@ -2689,7 +2690,7 @@ impl PyMLGraphBuilder { self.push_op(Operation::Cast { input: input.id, - to: data_type.to_string(), + data_type: ml_dtype, options: None, outputs: vec![output_id], });