Add input() and add_output() methods to GraphBuilder#2828
Add input() and add_output() methods to GraphBuilder#2828justinchuby wants to merge 4 commits intomainfrom
Conversation
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
This pull request enhances the GraphBuilder class by adding two new convenience methods for managing graph inputs and outputs. The changes simplify the process of creating graph inputs with rich metadata and adding outputs with optional renaming, improving the ergonomics of the graph-building API.
Changes:
- Added
input()method toGraphBuilderfor creating and registering graph input values with support for dtype, shape, type, const_value, and metadata properties - Added
add_output()method toGraphBuilderfor appending output values to the graph with optional renaming - Added comprehensive unit tests for both new methods, plus additional test coverage for the existing
initializer()method's qualification behavior
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
onnxscript/_internal/builder.py |
Implements the new input() and add_output() methods on the GraphBuilder class with full parameter support and documentation |
onnxscript/_internal/builder_test.py |
Adds three new test methods covering the behavior of input(), add_output(), and initializer() qualification |
| def test_add_output_renames_and_registers_output(self): | ||
| """Test that GraphBuilder.add_output renames (optionally) and appends outputs.""" | ||
| graph = ir.Graph( | ||
| name="test_model", | ||
| inputs=[], | ||
| outputs=[], | ||
| nodes=[], | ||
| opset_imports={"": _default_opset_version}, | ||
| ) | ||
| graph_builder = builder.GraphBuilder(graph) | ||
|
|
||
| output = ir.Value(name="old_name") | ||
| graph_builder.add_output(output, "new_name") | ||
|
|
||
| self.assertEqual(output.name, "new_name") | ||
| self.assertEqual(len(graph.outputs), 1) | ||
| self.assertIs(graph.outputs[0], output) |
There was a problem hiding this comment.
The test only verifies the case where a name is provided. Consider adding a test case where name=None is passed to ensure the method correctly handles the case where no renaming is needed, as documented in the method's docstring.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2828 +/- ##
==========================================
+ Coverage 71.78% 71.83% +0.05%
==========================================
Files 239 239
Lines 28989 29043 +54
Branches 2859 2861 +2
==========================================
+ Hits 20809 20864 +55
+ Misses 7209 7208 -1
Partials 971 971 ☔ View full report in Codecov by Sentry. |
| def input( | ||
| self, | ||
| name: str, | ||
| dtype: ir.DataType | None = None, |
There was a problem hiding this comment.
I was thinking about this too. I think if we could accommodate something like FLOAT['N', 1024] as a way of compactly specifying type and shape, it would help. Like it is done here.
There was a problem hiding this comment.
I was hesitant to pass in generic-looking type objects around. The behavior of generic type classes tend to be a bit unstable across different python versions (where things are stored, how data can be accessed, when something is evaluated, etc.). So I am not preferring it for now.
There was a problem hiding this comment.
What is unstable? Maybe we can fix it. Or support something like (dtype, ('N', 1024))
There was a problem hiding this comment.
In summary: the suggestions are
- a combined TypeAndShape would be useful and more compact (in some settings, not necessarily all).
- we could support the legacy onnxscript notation by adding a method to convert it to an ir.TypeAndShape(), and allowing objects that support a method called "toTypeAndShape" ... I currently have a to_ir method, but that name is too generic to be used safely for this purpose, but a name like "toTypeAndShape" should be reasonable.
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
This pull request adds new methods to the
GraphBuilderclass to simplify the creation and management of graph inputs and outputs, and introduces corresponding unit tests to ensure their correct behavior. The changes improve the usability and reliability of the graph-building API.Enhancements to the GraphBuilder API:
inputmethod toGraphBuilderfor creating and registering graph input values with support for specifying name, dtype, shape, type, constant value, and metadata properties.add_outputmethod toGraphBuilderto append an output value to the graph and optionally rename it.