Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d1d4051

Browse files
authored
Keep user-provided dynamic axes (#1442) (#1447)
1 parent d6f3f93 commit d1d4051

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/sparseml/pytorch/utils/exporter.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -492,19 +492,20 @@ def export_onnx(
492492
if "output_names" not in export_kwargs:
493493
export_kwargs["output_names"] = _get_output_names(out)
494494

495+
# Set all batch sizes to be dynamic
495496
if dynamic_axes is not None:
496-
warnings.warn(
497-
"`dynamic_axes` is deprecated and does not affect anything. "
498-
"The 0th axis is always treated as dynamic.",
499-
category=DeprecationWarning,
500-
)
501-
502-
dynamic_axes = {
503-
tensor_name: {0: "batch"}
504-
for tensor_name in (
505-
export_kwargs["input_names"] + export_kwargs["output_names"]
506-
)
507-
}
497+
for tensor_name in export_kwargs["input_names"] + export_kwargs["output_names"]:
498+
if tensor_name not in dynamic_axes:
499+
dynamic_axes[tensor_name] = {0: "batch"}
500+
else:
501+
dynamic_axes[tensor_name][0] = "batch"
502+
else:
503+
dynamic_axes = {
504+
tensor_name: {0: "batch"}
505+
for tensor_name in (
506+
export_kwargs["input_names"] + export_kwargs["output_names"]
507+
)
508+
}
508509

509510
# disable active quantization observers because they cannot be exported
510511
disabled_observers = []

0 commit comments

Comments
 (0)