Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ pyo3 = { version = "0.28", features = ["extension-module"] }
rustnn = { git = "https://github.com/rustnn/rustnn", branch = "main" }
serde_json = "1.0"
webnn-graph = { git = "https://github.com/rustnn/webnn-graph", branch = "main" }
safetensors = "0.7"
half = "2.4"
regex = "1.11"
# Optional runtime dependencies
Expand Down
47 changes: 41 additions & 6 deletions src/python/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use rustnn::Operation;

#[cfg(feature = "onnx-runtime")]
use rustnn::executors::onnx::{run_onnx_with_inputs, OnnxInput};
#[cfg(feature = "onnx-runtime")]
use std::borrow::Cow;

#[cfg(all(target_os = "macos", feature = "coreml-runtime"))]
use rustnn::executors::coreml::run_coreml_zeroed_cached_with_weights;
Expand Down Expand Up @@ -340,6 +342,19 @@ impl PyMLContext {
std::fs::write(output_path, &converted.data).map_err(|e| {
pyo3::exceptions::PyIOError::new_err(format!("Failed to write ONNX file: {}", e))
})?;
if let Some(weights) = &converted.weights_data {
let sidecar = std::path::Path::new(output_path)
.parent()
.unwrap_or_else(|| std::path::Path::new("."))
.join(rustnn::ONNX_EXTERNAL_WEIGHTS_FILENAME);
std::fs::write(&sidecar, weights).map_err(|e| {
pyo3::exceptions::PyIOError::new_err(format!(
"Failed to write ONNX external weights file `{}`: {}",
sidecar.display(),
e
))
})?;
}

Ok(())
}
Expand Down Expand Up @@ -1107,18 +1122,33 @@ impl PyMLContext {
pyo3::exceptions::PyRuntimeError::new_err(format!("ONNX conversion failed: {}", e))
})?;

let session = ort::session::Session::builder()
let rustnn::converters::ConvertedGraph {
data, weights_data, ..
} = converted;
let mut builder = ort::session::Session::builder()
.map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("Session builder failed: {}", e))
})?
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level1)
.map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("Set opt level failed: {}", e))
})?
.commit_from_memory(&converted.data)
.map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("Load model failed: {}", e))
})?;
if let Some(weights) = weights_data {
builder = builder
.with_external_initializer_file_in_memory(
rustnn::ONNX_EXTERNAL_WEIGHTS_FILENAME,
Cow::Owned(weights),
)
.map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"Set external initializer failed: {}",
e
))
})?;
}
let session = builder.commit_from_memory(&data).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("Load model failed: {}", e))
})?;

let session_arc = std::sync::Arc::new(session);
*session_guard = Some(std::sync::Arc::clone(&session_arc));
Expand Down Expand Up @@ -1265,7 +1295,12 @@ impl PyMLContext {
}

// Execute with ONNX runtime
let onnx_outputs = run_onnx_with_inputs(&converted.data, onnx_inputs).map_err(|e| {
let onnx_outputs = run_onnx_with_inputs(
&converted.data,
converted.weights_data.as_deref(),
onnx_inputs,
)
.map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!("ONNX execution failed: {}", e))
})?;

Expand Down
181 changes: 35 additions & 146 deletions src/python/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use pyo3::exceptions::PyIOError;
use pyo3::prelude::*;
use rustnn::graph::GraphInfo;
use rustnn::webnn_json;
use safetensors::SafeTensors;
use std::fs;
use std::path::Path;

Expand Down Expand Up @@ -358,9 +357,10 @@ impl PyMLGraph {
///
/// Args:
/// path: File path to load the graph from (e.g., "model.webnn")
/// manifest_path: Optional path to manifest.json file for external weights
/// weights_path: Optional path to weights file for external weights, or
/// a single .safetensors file (when manifest_path is not provided)
/// manifest_path: Optional path to manifest.json for manifest + raw weights layout
/// weights_path: Path to a `.safetensors` file, or to a raw `.weights` blob (with manifest
/// passed explicitly or discovered next to the graph). Relative paths are resolved from
/// the graph file’s parent directory (same rules as `webnn-graph`).
///
/// Returns:
/// MLGraph: The loaded graph
Expand All @@ -373,6 +373,7 @@ impl PyMLGraph {
/// graph = MLGraph.load("my_model.webnn")
/// graph = MLGraph.load("model.webnn", manifest_path="manifest.json", weights_path="model.weights")
/// graph = MLGraph.load("model.webnn", weights_path="model.safetensors")
/// graph = MLGraph.load("model.webnn", weights_path="custom_name.safetensors")
#[staticmethod]
#[pyo3(signature = (path, manifest_path=None, weights_path=None))]
fn load(path: &str, manifest_path: Option<&str>, weights_path: Option<&str>) -> PyResult<Self> {
Expand Down Expand Up @@ -400,7 +401,7 @@ impl PyMLGraph {
};

// Resolve external weight references if present
Self::resolve_external_weights(&mut graph_json, manifest_path, weights_path)?;
Self::resolve_external_weights(&mut graph_json, path_obj, manifest_path, weights_path)?;

// Convert GraphJson to GraphInfo
let graph_info = webnn_json::from_graph_json(&graph_json)
Expand All @@ -415,154 +416,42 @@ impl PyMLGraph {
Self { graph_info }
}

/// Resolve external weight references in a GraphJson
///
/// This function loads manifest and weights files and resolves all weight references to inline bytes.
///
/// If `weights_path` is not provided, returns immediately (no external weights).
/// Delegates to [`webnn_graph::resolve_external_weights`], then surfaces a Python error if any
/// `@weights` refs remain (e.g. partial safetensors coverage).
fn resolve_external_weights(
    graph_json: &mut webnn_graph::ast::GraphJson,
    manifest_path: Option<&str>,
    weights_path: Option<&str>,
) -> PyResult<()> {
    // A safetensors weights path takes priority; anything else goes through the
    // manifest + raw weights resolver.
    match Self::resolve_safetensors_path(weights_path) {
        Some(st_path) => Self::resolve_safetensors_weights(graph_json, st_path),
        None => Self::resolve_manifest_weights(graph_json, manifest_path, weights_path),
    }
}

/// Return `weights_path` unchanged when it names a safetensors file, otherwise `None`.
///
/// A path qualifies when the raw string ends with `.safetensors` or `.safetensor`
/// (case-sensitive suffix comparison). A `None` input yields `None`.
fn resolve_safetensors_path(weights_path: Option<&str>) -> Option<&str> {
    const SUFFIXES: [&str; 2] = [".safetensors", ".safetensor"];

    weights_path.filter(|candidate| SUFFIXES.iter().any(|suffix| candidate.ends_with(*suffix)))
}

/// Replace every `ConstInit::Weights` constant in `graph_json` with inline bytes read
/// from the safetensors file at `safetensors_path`.
///
/// Each reference is looked up by its exact name first. If that fails, it is looked up
/// through a table of sanitized tensor names in which `::` was replaced by `__` and `.`
/// by `_`.
///
/// # Errors
///
/// Returns a `PyIOError` when the file cannot be read or parsed, or when a referenced
/// weight is found under neither its exact nor its sanitized name. Constants visited
/// before the failing one have already been rewritten at that point.
fn resolve_safetensors_weights(
    graph_json: &mut webnn_graph::ast::GraphJson,
    safetensors_path: &str,
) -> PyResult<()> {
    use std::collections::HashMap;
    use webnn_graph::ast::ConstInit;

    let raw = fs::read(safetensors_path)
        .map_err(|e| PyIOError::new_err(format!("Failed to read safetensors: {}", e)))?;
    let tensors = SafeTensors::deserialize(&raw)
        .map_err(|e| PyIOError::new_err(format!("Failed to parse safetensors: {}", e)))?;

    // sanitized name -> original tensor name, for graphs whose refs had "." and "::"
    // rewritten to "_" and "__".
    let by_sanitized_name: HashMap<String, String> = tensors
        .names()
        .into_iter()
        .map(|name| (name.replace("::", "__").replace('.', "_"), name.to_string()))
        .collect();

    for (_, decl) in graph_json.consts.iter_mut() {
        // Only constants that still point at an external weight need resolving.
        let ConstInit::Weights { r#ref } = &decl.init else {
            continue;
        };

        let view = match tensors.tensor(r#ref) {
            Ok(view) => Ok(view),
            // Exact name missing: retry through the sanitized-name table.
            Err(_) => match by_sanitized_name.get(r#ref) {
                Some(original) => tensors.tensor(original),
                None => Err(safetensors::SafeTensorError::TensorNotFound(r#ref.clone())),
            },
        }
        .map_err(|e| {
            PyIOError::new_err(format!(
                "Weight '{}' not found in safetensors '{}': {}",
                r#ref, safetensors_path, e
            ))
        })?;

        decl.init = ConstInit::InlineBytes {
            bytes: view.data().to_vec(),
        };
    }

    Ok(())
}

fn resolve_manifest_weights(
graph_json: &mut webnn_graph::ast::GraphJson,
graph_path: &Path,
manifest_path: Option<&str>,
weights_path: Option<&str>,
) -> PyResult<()> {
use webnn_graph::ast::ConstInit;
use webnn_graph::weights::WeightsManifest;

// If no manifest path provided, assume no external weights
let manifest_path = match manifest_path {
Some(p) => p,
None => return Ok(()),
};

// If no weights path provided, assume no external weights
let weights_path = match weights_path {
Some(p) => p,
None => return Ok(()),
};

// Load manifest
let manifest_content = fs::read_to_string(manifest_path)
.map_err(|e| PyIOError::new_err(format!("Failed to read manifest: {}", e)))?;
let manifest: WeightsManifest = serde_json::from_str(&manifest_content)
.map_err(|e| PyIOError::new_err(format!("Failed to parse manifest: {}", e)))?;

// Load weights file
let weights_data = fs::read(weights_path)
.map_err(|e| PyIOError::new_err(format!("Failed to read weights: {}", e)))?;

// Create a sanitized lookup map: dots and colons in manifest keys -> underscores.
// Some graphs carry sanitized weight refs while others keep original dotted refs.
// We support both formats by checking exact refs first, then sanitized refs.
use std::collections::HashMap;
let sanitized_manifest: HashMap<String, _> = manifest
.tensors
.iter()
.map(|(key, value)| (key.replace("::", "__").replace('.', "_"), value))
.collect();

// Resolve weight references in constants
for (_name, const_decl) in graph_json.consts.iter_mut() {
if let ConstInit::Weights { r#ref } = &const_decl.init {
// Prefer exact key lookup for modern manifests, then fallback to sanitized.
let tensor_entry = manifest
.tensors
.get(r#ref)
.or_else(|| sanitized_manifest.get(r#ref).copied());

if let Some(tensor_entry) = tensor_entry {
let offset = tensor_entry.byte_offset as usize;
let length = tensor_entry.byte_length as usize;

// Extract bytes from weights file
if offset + length > weights_data.len() {
return Err(PyIOError::new_err(format!(
"Weight '{}' offset/length exceeds weights file size",
r#ref
)));
}
let bytes = weights_data[offset..offset + length].to_vec();

// Replace weight reference with inline bytes
const_decl.init = ConstInit::InlineBytes { bytes };
} else {
return Err(PyIOError::new_err(format!(
"Weight '{}' not found in manifest",
r#ref
)));
}
}
webnn_graph::resolve_external_weights(
graph_json,
graph_path,
weights_path,
manifest_path,
)
.map_err(|e| {
PyIOError::new_err(format!(
"Failed to resolve external weights: {e}. \
Pass weights_path to a `.safetensors` file, or manifest_path and weights_path for manifest + raw blob, \
or place sidecar files next to the graph."
))
})?;

let pending_count = graph_json
.consts
.values()
.filter(|c| matches!(c.init, ConstInit::Weights { .. }))
.count();
if pending_count > 0 {
return Err(PyIOError::new_err(format!(
"Graph still has {pending_count} external weight reference(s) after resolution. \
Pass weights_path to your `.safetensors` file, or manifest_path + weights_path for manifest + raw blob, \
or place model.safetensors / {{stem}}.safetensors or manifest + weights next to the graph. \
Rebuild the native extension (`pip install -e .` / maturin) if dependency changes are not picked up.",
)));
}

Ok(())
Expand Down
35 changes: 18 additions & 17 deletions src/python/graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use pyo3::types::PyDict;
use rustnn::graph::{
to_dimension_vector, ConstantData, DataType, GraphInfo, Operand, OperandDescriptor, OperandKind,
};
use rustnn::operator_enums::MLOperandDataType;
use rustnn::operator_options::{
MLArgMinMaxOptions, MLBatchNormalizationOptions, MLClampOptions, MLConv2dOptions,
MLConvTranspose2dOptions, MLDimension, MLEluOptions, MLGatherOptions, MLGemmOptions,
Expand Down Expand Up @@ -2528,9 +2529,9 @@ impl PyMLGraphBuilder {
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;

// Parse output data type, default to int64
let output_type = match output_data_type {
Some("int32") => DataType::Int32,
Some("int64") | None => DataType::Int64,
let (output_type, ml_output_dtype) = match output_data_type {
Some("int32") => (DataType::Int32, MLOperandDataType::Int32),
Some("int64") | None => (DataType::Int64, MLOperandDataType::Int64),
Some(other) => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid output_data_type '{}', must be 'int32' or 'int64'",
Expand All @@ -2551,7 +2552,7 @@ impl PyMLGraphBuilder {
let arg_opts = MLArgMinMaxOptions {
label: String::new(),
keep_dimensions,
output_data_type: output_data_type.unwrap_or("int64").to_string(),
output_data_type: ml_output_dtype,
};

self.push_op(Operation::ArgMax {
Expand Down Expand Up @@ -2602,9 +2603,9 @@ impl PyMLGraphBuilder {
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?;

// Parse output data type, default to int64
let output_type = match output_data_type {
Some("int32") => DataType::Int32,
Some("int64") | None => DataType::Int64,
let (output_type, ml_output_dtype) = match output_data_type {
Some("int32") => (DataType::Int32, MLOperandDataType::Int32),
Some("int64") | None => (DataType::Int64, MLOperandDataType::Int64),
Some(other) => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid output_data_type '{}', must be 'int32' or 'int64'",
Expand All @@ -2625,7 +2626,7 @@ impl PyMLGraphBuilder {
let arg_opts = MLArgMinMaxOptions {
label: String::new(),
keep_dimensions,
output_data_type: output_data_type.unwrap_or("int64").to_string(),
output_data_type: ml_output_dtype,
};

self.push_op(Operation::ArgMin {
Expand Down Expand Up @@ -2662,14 +2663,14 @@ impl PyMLGraphBuilder {
let output_shape = infer_cast_shape(&input.descriptor.static_or_max_shape());

// Parse target data type
let target_type = match data_type {
"float32" => DataType::Float32,
"float16" => DataType::Float16,
"int32" => DataType::Int32,
"uint32" => DataType::Uint32,
"int8" => DataType::Int8,
"uint8" => DataType::Uint8,
"int64" => DataType::Int64,
let (target_type, ml_dtype) = match data_type {
"float32" => (DataType::Float32, MLOperandDataType::Float32),
"float16" => (DataType::Float16, MLOperandDataType::Float16),
"int32" => (DataType::Int32, MLOperandDataType::Int32),
"uint32" => (DataType::Uint32, MLOperandDataType::Uint32),
"int8" => (DataType::Int8, MLOperandDataType::Int8),
"uint8" => (DataType::Uint8, MLOperandDataType::Uint8),
"int64" => (DataType::Int64, MLOperandDataType::Int64),
other => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid data_type '{}', must be one of: float32, float16, int32, uint32, int8, uint8, int64",
Expand All @@ -2689,7 +2690,7 @@ impl PyMLGraphBuilder {

self.push_op(Operation::Cast {
input: input.id,
to: data_type.to_string(),
data_type: ml_dtype,
options: None,
outputs: vec![output_id],
});
Expand Down
Loading