diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index bbbe6ab9e..e02fb173c 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -22,6 +22,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add PTQ support for GLM-4.7, including loading MTP layer weights from a separate ``mtp.safetensors`` file and export as-is.
 - Add support for image-text data calibration in PTQ for Nemotron VL models.
 - Add PTQ support for Nemotron Parse.
+- Replace modelopt FP8 QDQ nodes with native ONNX QDQ nodes
 
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py
index 278486c4b..d7987e649 100644
--- a/modelopt/onnx/autocast/precisionconverter.py
+++ b/modelopt/onnx/autocast/precisionconverter.py
@@ -1147,42 +1147,13 @@ def _is_same_type_cast(self, node: onnx.NodeProto) -> bool:
         output_type = utils.get_cast_to_type(node)
 
         return all(inp_type == output_type for inp_type in input_types) and input_types is not None
 
-    def _is_sequential_cast(self, node: onnx.NodeProto) -> bool:
-        assert node.op_type == "Cast"
-        output_type = utils.get_cast_to_type(node)
-
-        # Cast to high precision -> cast to low precision, first cast has no impact and can be safely removed
-        # Cast to low precision -> cast to high precision affects precision and should not be removed
-        precision_order = [
-            TensorProto.DOUBLE,
-            TensorProto.FLOAT,
-            TensorProto.FLOAT16,
-            TensorProto.BFLOAT16,
-        ]
-        consumers = [
-            n for n in utils.get_consumer_nodes(self.model, node.output[0]) if n.op_type == "Cast"
-        ]
-
-        # If the first cast has additional consumers, we should not remove it
-        if len(consumers) != 1:
-            return False
-
-        next_node = consumers[0]
-        first_cast_type = output_type
-        second_cast_type = utils.get_cast_to_type(next_node)
-
-        return (
-            first_cast_type in precision_order
-            and second_cast_type in precision_order
-            and precision_order.index(first_cast_type) <= precision_order.index(second_cast_type)
-        )
-
     def _remove_redundant_casts(self):
         """Removes both sequential casts and casts that don't change precision.
 
         This method optimizes the graph by removing unnecessary cast operations that either:
         1. Don't actually change the data type
         2. Could be replaced by a single cast operation
+        3. Can be folded into a preceding Constant node
         """
         if self.custom_ops:
             self.model = self._propagate_types_shapes_custom_ops(self.model)
@@ -1198,35 +1169,7 @@ def _remove_redundant_casts(self):
             check_type=True,
         )
 
-        nodes_to_remove = []
-        for node in self.model.graph.node:
-            if node.op_type == "Cast":
-                # Find cast nodes that don't change precision
-                if self._is_same_type_cast(node):
-                    nodes_to_remove.append(node)
-                    self._bypass_cast_node(node)
-                    logger.debug(f"Found redundant same-type cast: {node.name}")
-                    continue
-
-                # Find sequential casts that don't change precision
-                if self._is_sequential_cast(node):
-                    nodes_to_remove.append(node)
-                    self._bypass_cast_node(node)
-                    logger.debug(f"Found removable double-cast: {node.name}")
-
-                # Find foldable Constant -> Cast. Initializers are handled by _convert_initializers.
-                if self._is_foldable_constant_cast_pattern(node):
-                    nodes_to_remove.append(node)
-                    cast_producers = utils.get_producer_nodes(self.model, node.input[0])
-                    assert len(cast_producers) == 1 and cast_producers[0].op_type == "Constant"
-                    constant_producer = cast_producers[0]
-                    self._convert_constant_values(constant_producer, node)
-                    self._bypass_cast_node(node)
-                    logger.debug(f"Found foldable Constant->Cast pattern, removing {node.name}")
-
-        logger.debug(f"Removing redundant casts: {[n.name for n in nodes_to_remove]}")
-        for node in nodes_to_remove:
-            self.model.graph.node.remove(node)
+        self.model = onnx_utils.remove_redundant_casts(self.model)
 
     def _fix_network_output_names(self):
         modified = False
@@ -1360,80 +1303,6 @@ def _get_tensor_type(self, tensor_name):
             return self.initializer_map[tensor_name].data_type
         raise Exception(f"did not find tensor {tensor_name}")
 
-    def _convert_constant_values(self, const_node, cast_node: onnx.NodeProto) -> None:
-        original_tensor = const_node.attribute[0].t
-        if original_tensor.data_type == onnx.TensorProto.BFLOAT16:
-            original_data = onnx_utils.read_f16_tensor_as_fp32(original_tensor)
-        else:
-            original_data = onnx.numpy_helper.to_array(original_tensor)
-
-        # Precompute casted value
-        cast_to_type = utils.get_cast_to_type(cast_node)
-        cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to_type)
-
-        # Handle bfloat16 conversion manually since numpy doesn't support it natively
-        if cast_to_type == onnx.TensorProto.BFLOAT16:
-            casted_data = original_data.astype(ml_dtypes.bfloat16)
-        else:
-            casted_data = original_data.astype(cast_dtype)
-
-        # Create a new constant node with casted data
-        if cast_to_type == onnx.TensorProto.BFLOAT16:
-            # Create TensorProto manually for bfloat16
-            tensor_proto = onnx.TensorProto()
-            tensor_proto.name = const_node.output[0]
-            tensor_proto.data_type = onnx.TensorProto.BFLOAT16
-            tensor_proto.dims.extend(casted_data.shape)
-            # Convert bfloat16 to raw bytes
-            bf16_bytes = casted_data.astype(ml_dtypes.bfloat16).view(np.uint16)
-            tensor_proto.raw_data = bf16_bytes.tobytes()
-        else:
-            # Create tensor manually to ensure proper handling
-            tensor_proto = onnx.numpy_helper.from_array(casted_data)
-            tensor_proto.name = const_node.output[0]
-        new_const_node = onnx.helper.make_node(
-            "Constant",
-            inputs=[],
-            outputs=const_node.output,
-            value=tensor_proto,
-            name=const_node.name,
-        )
-
-        # Replace the original constant node with the new constant node
-        # The scope of this function is to convert the constant node data. Removing the cast is done later.
-        for node in utils.get_consumer_nodes(self.model, const_node.name):
-            for i, input_name in enumerate(node.input):
-                if input_name == const_node.name:
-                    node.input[i] = new_const_node.output[0]
-                    break
-
-        const_idx = -1
-        for i, node in enumerate(self.model.graph.node):
-            if node == const_node:
-                const_idx = i
-                break
-
-        self.model.graph.node.remove(const_node)
-        self.model.graph.node.insert(const_idx, new_const_node)
-        # The Cast node is the sole consumer of the Constant node, guaranteed by _is_foldable_constant_cast_pattern
-        cast_node.input[0] = new_const_node.output[0]
-
-    def _is_foldable_constant_cast_pattern(self, node: onnx.NodeProto) -> bool:
-        """Constant -> Cast and Cast is the only consumer of the Constant node."""
-        assert node.op_type == "Cast"
-
-        producer = utils.get_producer_nodes(self.model, node.input[0])
-
-        const_producer = (
-            producer[0] if len(producer) == 1 and producer[0].op_type == "Constant" else None
-        )
-
-        if const_producer:
-            get_consumer_nodes = utils.get_consumer_nodes(self.model, const_producer.output[0])
-            return len(get_consumer_nodes) == 1 and get_consumer_nodes[0] == node
-        return False
-
     def _sanitize_model(self):
         graph_sanitizer = GraphSanitizer(
             self.model,
diff --git a/modelopt/onnx/export/fp8_exporter.py b/modelopt/onnx/export/fp8_exporter.py
index 28e6b1da1..ffcbd8942 100644
--- a/modelopt/onnx/export/fp8_exporter.py
+++ b/modelopt/onnx/export/fp8_exporter.py
@@ -22,6 +22,8 @@
 import torch
 from onnx_graphsurgeon.ir.tensor import LazyValues
 
+from modelopt.onnx.logging_config import logger
+
 from .base_exporter import ONNXQuantExporter
 
 
@@ -45,13 +47,13 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         Even though modelopt supports FP8 onnx export, the weights are represented in fp32 + QDQ. The storage is therefore very bad.
         In this function, Q nodes will get removed from the weights and have only DQ nodes with those converted FP8
-        weights in the output model.
+        weights in the output model. TRT custom ops are converted to native ONNX DequantizeLinear.
 
         Parameters:
-            onnx_model: ONNX model with FP32/FP16 weights and QDQ nodes.
+            onnx_model: ONNX model with FP32/FP16 weights and TRT_FP8 QDQ nodes.
 
         Returns:
-            ONNX model with FP8 weights and only DQ nodes for weights (QDQ preserved for activations).
+            ONNX model with FP8 weights and native ONNX DQ nodes for weights (QDQ preserved for activations).
         """
         start_time = time.time()
         print("Replacing all (fp32 weights + fp8 QDQ) with (fp8 weights + DQ)...")
@@ -62,7 +64,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
 
         for node in graph.nodes:
             if node.op == "TRT_FP8QuantizeLinear":
-                # Should not remove input QDQ
+                # Should not remove input QDQ (only process weight quantization)
                 if not isinstance(node.inputs[0], gs.Constant):
                     continue
 
@@ -88,7 +90,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
                 onnx_weights_fp8 = gs.Constant(quantizer_name + "/fp8_weights", values)
                 node.outputs.clear()
 
-                # DQ Op is separated out
+                # Convert TRT DQ to native ONNX DequantizeLinear with FP8 weights
                 dq_op.inputs[0] = onnx_weights_fp8
                 dq_op.op = "DequantizeLinear"
                 dq_op.outputs[0].dtype = dq_op.inputs[1].dtype
@@ -101,5 +103,46 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
 
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Post-processes the ONNX model for FP8 quantization."""
-        return onnx_model
+        """Post-processes the ONNX model for FP8 quantization.
+
+        Converts TRT_FP8 QDQ ops to native ONNX QuantizeLinear/DequantizeLinear:
+        - TRT_FP8QuantizeLinear -> QuantizeLinear with FP8E4M3FN zero_point and saturate=1
+        - TRT_FP8DequantizeLinear -> DequantizeLinear
+
+        Args:
+            onnx_model: The ONNX model containing TRT_FP8 quantization nodes.
+
+        Returns:
+            The post-processed ONNX model with native ONNX quantization ops.
+        """
+        logger.info("Post-processing FP8 quantized model")
+        graph = gs.import_onnx(onnx_model)
+
+        # Convert TRT_FP8QuantizeLinear to native QuantizeLinear
+        for node in graph.nodes:
+            if node.op == "TRT_FP8QuantizeLinear":
+                node.op = "QuantizeLinear"
+                # Add FP8 zero_point if not present
+                if len(node.inputs) == 2:
+                    # Create FP8 zero point constant
+                    zp_tensor = onnx.TensorProto()
+                    zp_tensor.data_type = onnx.TensorProto.FLOAT8E4M3FN
+                    zp_tensor.dims.extend([1])  # 1-element tensor
+                    zp_tensor.raw_data = b"\x00"  # Zero in FP8
+                    zp_values = LazyValues(zp_tensor)
+                    zero_point = gs.Constant(node.name + "_zero_point", zp_values)
+                    node.inputs.append(zero_point)
+                # Add saturate attribute for FP8
+                node.attrs["saturate"] = 1
+                logger.debug(f"Converted {node.name} from TRT_FP8QuantizeLinear to QuantizeLinear")
+
+        # Convert TRT_FP8DequantizeLinear to native DequantizeLinear
+        for node in graph.nodes:
+            if node.op == "TRT_FP8DequantizeLinear":
+                node.op = "DequantizeLinear"
+                logger.debug(
+                    f"Converted {node.name} from TRT_FP8DequantizeLinear to DequantizeLinear"
+                )
+
+        graph.cleanup().toposort()
+        return gs.export_onnx(graph)
diff --git a/modelopt/onnx/export/nvfp4_exporter.py b/modelopt/onnx/export/nvfp4_exporter.py
index 416c2fdf8..a80a9845f 100644
--- a/modelopt/onnx/export/nvfp4_exporter.py
+++ b/modelopt/onnx/export/nvfp4_exporter.py
@@ -215,7 +215,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to process")
 
         for node in fp4_qdq_nodes:
-            idx = initializer_indices.get(node.input[0], None)
+            idx = initializer_indices.get(node.input[0])
             assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
 
             tensor = initializers[idx]
@@ -259,7 +259,7 @@ def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         fp4_qdq_nodes = [node for node in graph.node if node.op_type == "TRT_FP4QDQ"]
 
         for node in fp4_qdq_nodes:
-            idx = initializer_indices.get(node.input[0], None)
+            idx = initializer_indices.get(node.input[0])
             assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
 
             tensor = initializers[idx]
@@ -365,7 +365,7 @@ def _cast_input_dtypes(node: onnx.NodeProto, precision_dtype: str):
         logger.debug(f"Found {len(fp4_qdq_nodes)} FP4QDQ nodes to convert")
 
         for node in fp4_qdq_nodes:
-            idx = initializer_indices.get(node.input[0], None)
+            idx = initializer_indices.get(node.input[0])
             assert idx is not None, f"Initializer for weight '{node.input[0]}' not found."
             initializers_to_delete.append(graph.initializer[idx].name)
diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py
index 4025ea065..eb2ae27ff 100644
--- a/modelopt/onnx/utils.py
+++ b/modelopt/onnx/utils.py
@@ -433,8 +433,8 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes:
         if len(init.dims) > 1:
             dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
             if dtype in ["float16", "float32", "float64"]:
-                avg = weight_metadata.get(init.name + "_avg", None)
-                var = weight_metadata.get(init.name + "_var", None)
+                avg = weight_metadata.get(init.name + "_avg")
+                var = weight_metadata.get(init.name + "_var")
                 if avg and var:
                     numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype(
                         dtype
@@ -1215,6 +1215,110 @@ def onnx_type_str_to_enum(dtype: str) -> int:
     return getattr(onnx.TensorProto, dtype)
 
 
+def remove_redundant_casts(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
+    """Removes redundant Cast nodes from an ONNX model.
+
+    Handles three patterns:
+    1. Same-type casts: Cast where input type == output type (no-op)
+    2. Sequential casts: Cast(to=high_prec) -> Cast(to=low_prec), first cast removed
+    3. Constant->Cast folding: Fold cast into preceding Constant node's data
+
+    Args:
+        onnx_model: The ONNX model to optimize.
+
+    Returns:
+        onnx.ModelProto: Model with redundant casts removed.
+    """
+    import ml_dtypes
+
+    graph = gs.import_onnx(onnx_model)
+    removed_count = 0
+
+    # Precision ordering: lower index = higher precision
+    precision_order = {
+        onnx.TensorProto.DOUBLE: 0,
+        onnx.TensorProto.FLOAT: 1,
+        onnx.TensorProto.FLOAT16: 2,
+        onnx.TensorProto.BFLOAT16: 3,
+    }
+
+    def _get_onnx_type(tensor):
+        """Get ONNX type enum from a GS tensor's dtype."""
+        if tensor.dtype is None:
+            return None
+        try:
+            return onnx.helper.np_dtype_to_tensor_dtype(tensor.dtype)
+        except Exception:
+            return None
+
+    def _bypass_cast(node):
+        """Reconnect consumers of cast output to use cast input, removing the cast."""
+        inp = node.inputs[0]
+        out = node.outputs[0]
+        for consumer in list(out.outputs):
+            for i, consumer_inp in enumerate(consumer.inputs):
+                if consumer_inp is out:
+                    consumer.inputs[i] = inp
+        for i, graph_out in enumerate(graph.outputs):
+            if graph_out is out:
+                graph.outputs[i] = inp
+        node.outputs.clear()
+
+    for node in list(graph.nodes):
+        if node.op != "Cast":
+            continue
+
+        cast_to = node.attrs.get("to")
+        if cast_to is None:
+            continue
+
+        input_tensor = node.inputs[0]
+        output_tensor = node.outputs[0]
+
+        # Pattern 1: Same-type cast (no-op)
+        input_type = _get_onnx_type(input_tensor)
+        if input_type is not None and input_type == cast_to:
+            _bypass_cast(node)
+            removed_count += 1
+            logger.debug(f"Removed same-type cast: {node.name}")
+            continue
+
+        # Pattern 2: Sequential casts where first can be removed
+        # Cast(to=high) -> Cast(to=low): first cast has no effect
+        cast_consumers = output_tensor.outputs
+        if len(cast_consumers) == 1 and cast_consumers[0].op == "Cast":
+            next_cast_to = cast_consumers[0].attrs.get("to")
+            if (
+                cast_to in precision_order
+                and next_cast_to in precision_order
+                and precision_order[cast_to] <= precision_order[next_cast_to]
+            ):
+                _bypass_cast(node)
+                removed_count += 1
+                logger.debug(f"Removed sequential cast: {node.name}")
+                continue
+
+        # Pattern 3: Constant -> Cast folding (only if constant has single consumer)
+        if isinstance(input_tensor, Constant) and len(input_tensor.outputs) == 1:
+            try:
+                if cast_to == onnx.TensorProto.BFLOAT16:
+                    input_tensor.values = input_tensor.values.astype(ml_dtypes.bfloat16)
+                else:
+                    cast_dtype = onnx.helper.tensor_dtype_to_np_dtype(cast_to)
+                    input_tensor.values = input_tensor.values.astype(cast_dtype)
+                _bypass_cast(node)
+                removed_count += 1
+                logger.debug(f"Folded Constant->Cast: {node.name}")
+            except Exception as e:
+                logger.debug(f"Failed to fold Constant->Cast {node.name}: {e}")
+
+    if removed_count > 0:
+        graph.cleanup().toposort()
+        logger.info(f"Removed {removed_count} redundant Cast nodes")
+
+    return gs.export_onnx(graph)
+
+
 def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
     """Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
 
@@ -1263,3 +1367,43 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) ->
     onnx_model.graph.value_info.extend(keep)
 
     return onnx_model
+
+
+def change_casts_to_fp16(model: onnx.ModelProto, target_op_types: list[str]) -> onnx.ModelProto:
+    """Change Cast nodes that cast to FP32 and feed into specified nodes to cast to FP16 instead.
+
+    Args:
+        model: The ONNX model to modify.
+        target_op_types: List of op types to check for. Cast nodes feeding into these will be
+            changed from FP32 to FP16.
+
+    Returns:
+        The modified ONNX model with Cast nodes updated.
+    """
+    # Build a map of tensor name -> consumer nodes
+    tensor_to_consumers: dict[str, list[onnx.NodeProto]] = {}
+    for node in model.graph.node:
+        for inp in node.input:
+            if inp:
+                tensor_to_consumers.setdefault(inp, []).append(node)
+
+    # Find Cast nodes that feed into target ops and change FP32 -> FP16
+    for node in model.graph.node:
+        if node.op_type != "Cast":
+            continue
+
+        # Check if this Cast outputs to a target op type
+        cast_output = node.output[0]
+        consumers = tensor_to_consumers.get(cast_output, [])
+        feeds_target = any(c.op_type in target_op_types for c in consumers)
+
+        if not feeds_target:
+            continue
+
+        # Check if Cast is to FP32, and change to FP16
+        for attr in node.attribute:
+            if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
+                attr.i = onnx.TensorProto.FLOAT16
+                break
+
+    return model
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index 304fb8ec7..b89a7a4e9 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -21,17 +21,17 @@
 import os
 import shutil
 import tempfile
-from contextlib import nullcontext
+from contextlib import nullcontext, suppress
 from typing import Any
 
 import onnx
+import onnxconverter_common.float16 as _f16_module
 import torch
 import torch.nn as nn
 from onnx import ModelProto
 from onnxconverter_common import convert_float_to_float16
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
-from modelopt.onnx.autocast.convert import convert_to_f16
 from modelopt.onnx.export import (
     FP8QuantExporter,
     INT4QuantExporter,
@@ -42,6 +42,7 @@
 )
 from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
 from modelopt.onnx.utils import (
+    change_casts_to_fp16,
     check_model_uses_external_data,
     get_input_names,
     get_input_shapes,
@@ -50,6 +51,7 @@
     get_output_shapes,
     infer_shapes,
     remove_node_training_mode,
+    remove_redundant_casts,
 )
 from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 from modelopt.torch.utils import flatten_tree, standardize_named_model_args
@@ -57,6 +59,17 @@
 
 from ..utils.onnx_optimizer import Optimizer
 
+# Monkey-patch to fix onnxconverter_common bug where downstream_node is a list
+_original_remove_unnecessary_cast_node = _f16_module.remove_unnecessary_cast_node
+
+
+def _patched_remove_unnecessary_cast_node(graph):
+    with suppress(AttributeError):
+        _original_remove_unnecessary_cast_node(graph)
+
+
+_f16_module.remove_unnecessary_cast_node = _patched_remove_unnecessary_cast_node
+
 ModelMetadata = dict[str, Any]
 ModelType = Any
 ValueInfoType = Any
@@ -560,38 +573,30 @@ def get_onnx_bytes_and_metadata(
         tree_spec_input, tree_spec_output, input_none_names, onnx_opt_graph, model
     )
 
-    # TODO: Remove manual ir_version change once ORT supports ir_version 11
-    onnx_opt_graph.ir_version = 10
-
-    # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode
-    # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode
-    if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model):
-        onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
+    onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
 
     if dq_only:
        onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
 
-    try:
-        # TODO: Single-precision torch model assumed
-        param_dtype = next(model.parameters()).dtype
-    except StopIteration:
-        param_dtype = torch.float32
-    if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-        if is_int4_quantized(model) or is_mxfp8_quantized(model):
-            assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
-            onnx_opt_graph = convert_float_to_float16(
-                onnx_opt_graph,
-                keep_io_types=False,
-                disable_shape_infer=True,
-                check_fp16_ready=False,
-            )
-        else:
-            onnx_opt_graph = convert_to_f16(
-                onnx_opt_graph, low_precision_type=weights_dtype, keep_io_types=False
-            )
+    if weights_dtype == "fp16":
+        onnx_opt_graph = convert_float_to_float16(
+            onnx_opt_graph,
+            keep_io_types=False,
+            disable_shape_infer=True,
+            check_fp16_ready=False,
+            op_block_list=["QuantizeLinear", "DequantizeLinear", "Div"],
+        )
+        # Change FP32 cast nodes feeding into Concat/Add to FP16
+        onnx_opt_graph = change_casts_to_fp16(onnx_opt_graph, ["Concat", "Add"])
 
-    # TensorRT expects all scales to be postive
-    onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
+    onnx_opt_graph = remove_redundant_casts(onnx_opt_graph)
+
+    # TensorRT expects all scales to be positive
+    onnx_opt_graph = replace_zero_scale_with_smallest_nonzero(onnx_opt_graph)
+
+    # TODO: Remove manual ir_version change once ORT supports ir_version 11
+    # Must be set after all gs.export_onnx() calls as graphsurgeon resets ir_version
+    onnx_opt_graph.ir_version = 10
 
     # If the onnx model contains external data store the external tensors in one file and save the onnx model
     if has_external_data(onnx_save_path):
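Usage sketch (illustrative only, not part of the diff above): the snippet shows how the two helpers introduced in modelopt/onnx/utils.py by this change might be applied to a standalone ONNX model, in the same order that torch_onnx.py invokes them; the model file names are hypothetical.

    import onnx

    from modelopt.onnx.utils import change_casts_to_fp16, remove_redundant_casts

    # Hypothetical input: a model whose weights were already converted to FP16.
    model = onnx.load("model_fp16.onnx")

    # Retarget FP32 Casts that feed Concat/Add so they cast to FP16 instead,
    # mirroring the cleanup done after convert_float_to_float16 in torch_onnx.py.
    model = change_casts_to_fp16(model, ["Concat", "Add"])

    # Drop same-type casts, removable sequential casts, and fold Constant->Cast pairs.
    model = remove_redundant_casts(model)

    onnx.save(model, "model_fp16_cleaned.onnx")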