1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
- Add support for image-text data calibration in PTQ for Nemotron VL models.
- Add PTQ support for Nemotron Parse.
- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
135 changes: 2 additions & 133 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -1147,42 +1147,13 @@ def _is_same_type_cast(self, node: onnx.NodeProto) -> bool:
output_type = utils.get_cast_to_type(node)
return all(inp_type == output_type for inp_type in input_types) and input_types is not None

def _is_sequential_cast(self, node: onnx.NodeProto) -> bool:
assert node.op_type == "Cast"
output_type = utils.get_cast_to_type(node)

# Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
# Cast to low precision -> cast to high precision affects precision and should not be removed
precision_order = [
TensorProto.DOUBLE,
TensorProto.FLOAT,
TensorProto.FLOAT16,
TensorProto.BFLOAT16,
]
consumers = [
n for n in utils.get_consumer_nodes(self.model, node.output[0]) if n.op_type == "Cast"
]

# If the first cast has additional consumers, we should not remove it
if len(consumers) != 1:
return False

next_node = consumers[0]
first_cast_type = output_type
second_cast_type = utils.get_cast_to_type(next_node)

return (
first_cast_type in precision_order
and second_cast_type in precision_order
and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
)
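
# Illustrative sketch of the rule encoded above (assumed helper names, not part of this
# module): an up-cast followed by an equal-or-lower-precision cast has no numerical effect
# on the first Cast, so it can be dropped; the reverse order loses bits and must stay.
from onnx import TensorProto

_PRECISION_ORDER = [TensorProto.DOUBLE, TensorProto.FLOAT, TensorProto.FLOAT16, TensorProto.BFLOAT16]

def first_cast_is_removable(first_to: int, second_to: int) -> bool:
    """Return True when the first of two back-to-back Casts can be removed."""
    if first_to not in _PRECISION_ORDER or second_to not in _PRECISION_ORDER:
        return False
    return _PRECISION_ORDER.index(first_to) <= _PRECISION_ORDER.index(second_to)

assert first_cast_is_removable(TensorProto.FLOAT, TensorProto.FLOAT16)      # fp16 -> fp32 -> fp16: first Cast removable
assert not first_cast_is_removable(TensorProto.FLOAT16, TensorProto.FLOAT)  # fp32 -> fp16 -> fp32: keeps losing bits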

def _remove_redundant_casts(self):
"""Removes both sequential casts and casts that don't change precision.

This method optimizes the graph by removing unnecessary cast operations that either:
1. Don't actually change the data type
2. Could be replaced by a single cast operation
3. Can be folded into a preceding Constant node
"""
if self.custom_ops:
self.model = self._propagate_types_shapes_custom_ops(self.model)
@@ -1198,35 +1169,7 @@ def _remove_redundant_casts(self):
check_type=True,
)

nodes_to_remove = []
for node in self.model.graph.node:
if node.op_type == "Cast":
# Find cast nodes that don't change precision
if self._is_same_type_cast(node):
nodes_to_remove.append(node)
self._bypass_cast_node(node)
logger.debug(f"Found redundant same-type cast: {node.name}")
continue

# Find sequential casts that don't change precision
if self._is_sequential_cast(node):
nodes_to_remove.append(node)
self._bypass_cast_node(node)
logger.debug(f"Found removable double-cast: {node.name}")

# Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
if self._is_foldable_constant_cast_pattern(node):
nodes_to_remove.append(node)
cast_producers = utils.get_producer_nodes(self.model, node.input[0])
assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant"
constant_producer = cast_producers[0]
self._convert_constant_values(constant_producer, node)
self._bypass_cast_node(node)
logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}")

logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}")
for node in nodes_to_remove:
self.model.graph.node.remove(node)
self.model = onnx_utils.remove_redundant_casts(self.model)

def _fix_network_output_names(self):
modified = False
@@ -1360,80 +1303,6 @@ def _get_tensor_type(self, tensor_name):
return self.initializer_map[tensor_name].data_type
raise Exception(f"did not find tensor {tensor_name}")

def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> None:
original_tensor = const_node.attribute[0].t
if original_tensor.data_type == onnx.TensorProto.BFLOAT16:
original_data = onnx_utils.read_f16_tensor_as_fp32(original_tensor)
else:
original_data = onnx.numpy_helper.to_array(original_tensor)

# Precompute casted value
cast_to_type = utils.get_cast_to_type(cast_node)
cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type)

# Handle bfloat16 conversion manually since numpy doesn't support it natively
if cast_to_type == onnx.TensorProto.BFLOAT16:
casted_data = original_data.astype(ml_dtypes.bfloat16)
else:
casted_data = original_data.astype(cast_dtype)

# Create a new constant node with casted data
if cast_to_type == onnx.TensorProto.BFLOAT16:
# Create TensorProto manually for bfloat16
tensor_proto = onnx.TensorProto()
tensor_proto.name = const_node.output[0]
tensor_proto.data_type = onnx.TensorProto.BFLOAT16
tensor_proto.dims.extend(casted_data.shape)
# Convert bfloat16 to raw bytes
bf16_bytes = casted_data.astype(ml_dtypes.bfloat16).view(np.uint16)
tensor_proto.raw_data = bf16_bytes.tobytes()
else:
# Create tensor manually to ensure proper handling
tensor_proto = onnx.numpy_helper.from_array(casted_data)
tensor_proto.name = const_node.output[0]

new_const_node = onnx.helper.make_node(
"Constant",
inputs=[],
outputs=const_node.output,
value=tensor_proto,
name=const_node.name,
)

# Replace the original constant node with the new constant node
# The scope of this function is to convert the constant node data. Removing the cast is done later.
for node in utils.get_consumer_nodes(self.model, const_node.name):
for i, input_name in enumerate(node.input):
if input_name == const_node.name:
node.input[i] = new_const_node.output[0]
break

const_idx = -1
for i, node in enumerate(self.model.graph.node):
if node == const_node:
const_idx = i
break

self.model.graph.node.remove(const_node)
self.model.graph.node.insert(const_idx, new_const_node)
# The Cast node is the sole consumer of the Constant node, guaranteed by _is_foldable_constant_cast_pattern
cast_node.input[0] = new_const_node.output[0]

def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
"""Constant -> Cast and Cast is the only consumer of the Constant node."""
assert node.op_type == "Cast"

producer = utils.get_producer_nodes(self.model, node.input[0])

const_producer = (
producer[0] if len(producer) == 1 and producer[0].op_type == "Constant" else None
)

if const_producer:
get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
return False

def _sanitize_model(self):
graph_sanitizer = GraphSanitizer(
self.model,
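The Constant -> Cast folding removed above (now delegated, together with the cast rules, to onnx_utils.remove_redundant_casts) boils down to re-encoding the Constant's payload in the Cast's target dtype so the Cast becomes a droppable same-type cast. A minimal standalone sketch with plain onnx/numpy, using hypothetical helper names and skipping the bfloat16 special case that needs ml_dtypes:

import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

def fold_constant_cast(const_node: onnx.NodeProto, cast_to: int) -> onnx.NodeProto:
    """Return a Constant whose value is pre-cast to `cast_to` (illustration only)."""
    original = numpy_helper.to_array(const_node.attribute[0].t)  # attribute[0] is `value`
    np_dtype = helper.tensor_dtype_to_np_dtype(cast_to)
    folded = numpy_helper.from_array(original.astype(np_dtype), const_node.output[0])
    return helper.make_node("Constant", [], list(const_node.output), value=folded, name=const_node.name)

# A float32 Constant feeding Cast(to=FLOAT16) becomes a float16 Constant; the Cast is
# then a same-type cast and can be removed by the same-type rule.
const = helper.make_node("Constant", [], ["w"], value=numpy_helper.from_array(np.ones((2, 2), np.float32), "w"))
folded = fold_constant_cast(const, TensorProto.FLOAT16)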
57 changes: 50 additions & 7 deletions modelopt/onnx/export/fp8_exporter.py
@@ -22,6 +22,8 @@
import torch
from onnx_graphsurgeon.ir.tensor import LazyValues

from modelopt.onnx.logging_config import logger

from .base_exporter import ONNXQuantExporter


@@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
Even though modelopt supports FP8 ONNX export, the weights are represented in FP32 plus QDQ nodes,
which makes the exported files unnecessarily large. In this function, Q nodes are removed from the
weight path, leaving only DQ nodes that consume the converted FP8
weights in the output model.
weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear.

Parameters:
onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes.
onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes.

Returns:
ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations).
ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations).
"""
start_time = time.time()
print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...")
@@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:

for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
# Should not remove input QDQ
# Should not remove input QDQ (only process weight quantization)
if not isinstance(node.inputs[0], gs.Constant):
continue

@@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)

node.outputs.clear()
# DQ Op is separated out
# Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights
dq_op.inputs[0] = onnx_weights_fp8
dq_op.op = "DequantizeLinear"
dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
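
For reference, the weight pattern this produces (an FP8 initializer feeding a plain DequantizeLinear in front of the consumer) can be built standalone as sketched below; toy values and names, assuming onnx>=1.14 for FP8 types, not the exporter's code:

import numpy as np
import onnx
from onnx import TensorProto, helper

# Hypothetical 2x2 weight: one byte per element, already encoded as FP8 E4M3FN.
w_fp8 = helper.make_tensor("w_fp8", TensorProto.FLOAT8E4M3FN, [2, 2], np.zeros(4, np.uint8).tobytes(), raw=True)
w_scale = helper.make_tensor("w_scale", TensorProto.FLOAT, [], [0.02])
dq = helper.make_node("DequantizeLinear", ["w_fp8", "w_scale"], ["w_dequant"], name="weight/DQ")
matmul = helper.make_node("MatMul", ["x", "w_dequant"], ["y"])
graph = helper.make_graph(
    [dq, matmul],
    "fp8_weight_dq",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 2])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 2])],
    initializer=[w_fp8, w_scale],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])
onnx.checker.check_model(model)  # DequantizeLinear accepts FP8 inputs from opset 19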
@@ -101,5 +103,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for FP8 quantization."""
return onnx_model
"""Post-processes the ONNX model for FP8 quantization.

Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
- TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
- TRT_FP8DequantizeLinear -> DequantizeLinear

Args:
onnx_model: The ONNX model containing TRT_FP8 quantization nodes.

Returns:
The post-processed ONNX model with native ONNX quantization ops.
"""
logger.info("Post-processing FP8 quantized model")
graph = gs.import_onnx(onnx_model)

# Convert TRT_FP8QuantizeLinear to native QuantizeLinear
for node in graph.nodes:
if node.op == "TRT_FP8QuantizeLinear":
node.op = "QuantizeLinear"
# Add FP8 zero_point if not present
if len(node.inputs) == 2:
# Create FP8 zero point constant
zp_tensor = onnx.TensorProto()
zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
zp_tensor.dims.extend([1]) # 1-element tensor
zp_tensor.raw_data = b"\x00" # Zero in FP8
zp_values = LazyValues(zp_tensor)
zero_point = gs.Constant(node.name + "_zero_point", zp_values)
node.inputs.append(zero_point)
# Add saturate attribute for FP8
node.attrs["saturate"] = 1
logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")

# Convert TRT_FP8DequantizeLinear to native DequantizeLinear
for node in graph.nodes:
if node.op == "TRT_FP8DequantizeLinear":
node.op = "DequantizeLinear"
logger.debug(
f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
)

graph.cleanup().toposort()
return gs.export_onnx(graph)
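
One way to sanity-check the converted model is to assert that no TRT_FP8 custom ops survive and that every QuantizeLinear now carries the three-input form with saturate=1; a hypothetical helper, not part of the exporter:

import onnx

def check_native_fp8_qdq(model: onnx.ModelProto) -> None:
    """Assert the post-processed graph only uses native ONNX FP8 QDQ ops."""
    for node in model.graph.node:
        assert not node.op_type.startswith("TRT_FP8"), f"leftover TRT custom op: {node.name}"
        if node.op_type == "QuantizeLinear":
            assert len(node.input) == 3, f"{node.name} is missing its FP8 zero_point input"
            saturate = [attr.i for attr in node.attribute if attr.name == "saturate"]
            assert saturate == [1], f"{node.name} should saturate to the FP8E4M3FN range"

# Usage (hypothetical file name):
# check_native_fp8_qdq(onnx.load("model_quantized_fp8.onnx"))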
6 changes: 3 additions & 3 deletions modelopt/onnx/export/nvfp4_exporter.py
@@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process")

for node in fp4_qdq_nodes:
idx = initializer_indices.get(node.input[0], None)
idx = initializer_indices.get(node.input[0])
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."

tensor = initializers[idx]
@@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"]

for node in fp4_qdq_nodes:
idx = initializer_indices.get(node.input[0], None)
idx = initializer_indices.get(node.input[0])
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."

tensor = initializers[idx]
@@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert")

for node in fp4_qdq_nodes:
idx = initializer_indices.get(node.input[0], None)
idx = initializer_indices.get(node.input[0])
assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
initializers_to_delete.append(graph.initializer[idx].name)
