diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 0000000..e7a488f --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1,7 @@ +## 2026-01-31 - [O(1) Pattern Matching vs O(N) Wildcards] +**Learning:** Using a general `Any()` wildcard pattern in a rewrite pass that only targets specific operations is a performance anti-pattern. It forces the optimizer to call the rewriter for EVERY node in the graph, resulting in O(N) complexity. By registering specific `Op()` patterns, the optimizer can use its internal indexing for O(1) matching. +**Action:** Always prefer specific `Op()` patterns over `Any()` wildcards when the set of target operations is known. Refactor `PatternRewritePass` to support multiple specific patterns if a pass targets multiple op types. + +## 2026-01-31 - [Const Node Shape Discovery] +**Learning:** TensorFlow `Const` nodes often store their shape information within the `tensor` proto of the `value` attribute, rather than a top-level `shape` attribute or `_output_shapes`. Standard shape extraction utilities must explicitly check the `tensor_shape` field within the `value` attribute's tensor to avoid losing shape information for constants. +**Action:** Ensure `get_node_shape` utilities handle the `Const` value tensor as a fallback for shape discovery. diff --git a/core.py b/core.py index 25d8232..ebabb00 100644 --- a/core.py +++ b/core.py @@ -321,6 +321,12 @@ def get_node_shape(self, node_or_name): if "shape" in node.attr: return [dim.size for dim in node.attr["shape"].shape.dim] + # Handle Const nodes by reading shape from the value attribute's tensor + if node.op == "Const" and "value" in node.attr: + tensor = node.attr["value"].tensor + if tensor.HasField("tensor_shape"): + return [dim.size for dim in tensor.tensor_shape.dim] + return None def get_node_rank(self, node_or_name): @@ -1516,12 +1522,35 @@ class PatternRewritePass(BasePass): for the actual pattern matching. Iterates until convergence (no more matches). """ - def __init__(self, pattern, rewriter, name=None, optimizer_alias=None): + def __init__(self, patterns=None, rewriter=None, name=None, optimizer_alias=None, **kwargs): + """ + Initialize a pattern rewrite pass. + + Args: + patterns: A single Pattern or a list of Patterns to match. (Also accepts 'pattern' for compat) + rewriter: The rewriter function to execute on matches. + name: Human-readable pass name. + optimizer_alias: Short alias for node naming. + """ + # Support legacy 'pattern' keyword argument + if patterns is None: + patterns = kwargs.pop("pattern", None) + + if patterns is None: + raise ValueError("At least one pattern must be provided to PatternRewritePass") + # Use iterative mode - run until convergence super().__init__(name, optimizer_alias, iterative=True, max_iterations=100) - self.pattern = pattern + if not isinstance(patterns, list): + patterns = [patterns] + self.patterns = patterns self.rewriter = trace_transformation(rewriter) + @property + def pattern(self): + """Property for backward compatibility (returns the first pattern).""" + return self.patterns[0] if self.patterns else None + def transform_once( self, optimizer: GraphOptimizer, @@ -1534,9 +1563,10 @@ def transform_once( Returns: int: Number of changes made """ - # Register the pattern (clear first to avoid duplicates) + # Register all patterns (clear first to avoid duplicates) optimizer.clear_transformations() - optimizer.add_transformation(self.pattern, self.rewriter) + for pattern in self.patterns: + optimizer.add_transformation(pattern, self.rewriter) # Run one pattern matching iteration new_graph_def, changes = optimizer.match_patterns_once( diff --git a/transforms/scalar/algebraic_simplify.py b/transforms/scalar/algebraic_simplify.py index 76c94af..518ffd7 100644 --- a/transforms/scalar/algebraic_simplify.py +++ b/transforms/scalar/algebraic_simplify.py @@ -87,9 +87,15 @@ class AlgebraicSimplifyPass(PatternRewritePass): """ def __init__(self): - # We'll handle multiple patterns manually in _rewrite - pattern = Any(alias="op") # fallback, we check inside - super().__init__(pattern, self._rewrite, name="AlgebraicSimplify") + # Specific operations supported by this pass + supported_ops = [ + "Add", "Sub", "Mul", "Div", "Neg", "LogicalNot", "Abs", "Square", "Sqrt", "Pow", + "Equal", "NotEqual", "Less", "Greater", "LessEqual", "GreaterEqual", + "LogicalAnd", "LogicalOr", "Select", "Identity" + ] + # Register specific Op patterns instead of a catch-all Any() for O(1) matching + patterns = [Op(op, alias="op") for op in supported_ops] + super().__init__(patterns, self._rewrite, name="AlgebraicSimplify") def _rewrite(self, match, optimizer): node = match.matched_nodes["op"] @@ -127,20 +133,9 @@ def _is_const(node_name, value): # Check if all elements are equal to the target value return np.all(np.equal(val, value)) - # Helper to get shape of a node + # Use improved core.py utility for shape discovery def _get_shape(node_name): - node = _get_node(node_name) - if node is None: - return None - # Check for shape attribute (Placeholder, etc.) - if "shape" in node.attr: - return [d.size for d in node.attr["shape"].shape.dim] - # Check for Const value shape - if node.op == "Const" and "value" in node.attr: - tensor = node.attr["value"].tensor - if tensor.HasField("tensor_shape"): - return [d.size for d in tensor.tensor_shape.dim] - return None + return optimizer.get_node_shape(node_name) # Helper to check if a node is definitely scalar def _is_scalar(node_name): diff --git a/transforms/scalar/constant_fold.py b/transforms/scalar/constant_fold.py index 187cff4..513f145 100644 --- a/transforms/scalar/constant_fold.py +++ b/transforms/scalar/constant_fold.py @@ -58,9 +58,18 @@ class ConstantFoldPass(PatternRewritePass): """ def __init__(self): - # Matches any operation with all inputs as Const - pattern = Any(alias="op") - super().__init__(pattern, self._rewrite_constant_op, name="ConstantFold") + # Specific operations supported for constant folding + foldable_ops = [ + "Add", "Mul", "Sub", "Div", "Neg", "Equal", "NotEqual", "Less", "Greater", + "LessEqual", "GreaterEqual", "LogicalAnd", "LogicalOr", "LogicalNot", + "BitwiseAnd", "BitwiseOr", "BitwiseXor", "Abs", "Exp", "Expm1", "Log", + "Log1p", "Sqrt", "Pow", "Rsqrt", "Square", "Sin", "Cos", "Tan", "Asin", + "Acos", "Atan", "Atan2", "Floor", "Ceil", "Round", "Sign", + "Reshape", "Transpose", "ConcatV2", "Select", "Cast" + ] + # Register specific Op patterns instead of a catch-all Any() for O(1) matching + patterns = [Op(op, alias="op") for op in foldable_ops] + super().__init__(patterns, self._rewrite_constant_op, name="ConstantFold") def _is_all_const(self, inputs, optimizer): """Check if all inputs are Const nodes.