Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## 2026-01-31 - [O(1) Pattern Matching vs O(N) Wildcards]
**Learning:** Using a general `Any()` wildcard pattern in a rewrite pass that only targets specific operations is a performance anti-pattern. It forces the optimizer to invoke the rewriter for EVERY node in the graph, i.e. O(N) rewriter calls per pass. By registering specific `Op()` patterns instead, the optimizer can use its internal op-type index for O(1) dispatch per candidate node.
**Action:** Always prefer specific `Op()` patterns over `Any()` wildcards when the set of target operations is known. Refactor `PatternRewritePass` to support multiple specific patterns if a pass targets multiple op types.

## 2026-01-31 - [Const Node Shape Discovery]
**Learning:** TensorFlow `Const` nodes often store their shape information within the `tensor` proto of the `value` attribute, rather than a top-level `shape` attribute or `_output_shapes`. Standard shape extraction utilities must explicitly check the `tensor_shape` field within the `value` attribute's tensor to avoid losing shape information for constants.
**Action:** Ensure `get_node_shape` utilities handle the `Const` value tensor as a fallback for shape discovery.
38 changes: 34 additions & 4 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,12 @@ def get_node_shape(self, node_or_name):
if "shape" in node.attr:
return [dim.size for dim in node.attr["shape"].shape.dim]

# Handle Const nodes by reading shape from the value attribute's tensor
if node.op == "Const" and "value" in node.attr:
tensor = node.attr["value"].tensor
if tensor.HasField("tensor_shape"):
return [dim.size for dim in tensor.tensor_shape.dim]

return None

def get_node_rank(self, node_or_name):
Expand Down Expand Up @@ -1516,12 +1522,35 @@ class PatternRewritePass(BasePass):
for the actual pattern matching. Iterates until convergence (no more matches).
"""

def __init__(self, patterns=None, rewriter=None, name=None, optimizer_alias=None, **kwargs):
    """
    Initialize a pattern rewrite pass.

    Args:
        patterns: A single Pattern, or a list/tuple of Patterns to match.
            (Also accepts the legacy keyword 'pattern' for compatibility.)
        rewriter: The rewriter function to execute on matches. Required.
        name: Human-readable pass name.
        optimizer_alias: Short alias for node naming.

    Raises:
        ValueError: If no pattern or no rewriter is provided.
        TypeError: If unrecognized keyword arguments are passed.
    """
    # Support legacy 'pattern' keyword argument.
    if patterns is None:
        patterns = kwargs.pop("pattern", None)

    # Fail loudly on typos instead of silently ignoring leftover kwargs.
    if kwargs:
        raise TypeError(
            f"Unexpected keyword arguments: {sorted(kwargs)}"
        )

    if patterns is None:
        raise ValueError("At least one pattern must be provided to PatternRewritePass")
    if rewriter is None:
        raise ValueError("A rewriter callable must be provided to PatternRewritePass")

    # Use iterative mode - run until convergence.
    super().__init__(name, optimizer_alias, iterative=True, max_iterations=100)

    # Normalize to a list; accept tuples as well so a tuple of Patterns
    # is not mistaken for a single pattern.
    if isinstance(patterns, (list, tuple)):
        patterns = list(patterns)
    else:
        patterns = [patterns]
    self.patterns = patterns
    self.rewriter = trace_transformation(rewriter)

@property
def pattern(self):
    """Backward-compatible accessor: the first registered pattern, or None if empty."""
    if not self.patterns:
        return None
    return self.patterns[0]

def transform_once(
self,
optimizer: GraphOptimizer,
Expand All @@ -1534,9 +1563,10 @@ def transform_once(
Returns:
int: Number of changes made
"""
# Register the pattern (clear first to avoid duplicates)
# Register all patterns (clear first to avoid duplicates)
optimizer.clear_transformations()
optimizer.add_transformation(self.pattern, self.rewriter)
for pattern in self.patterns:
optimizer.add_transformation(pattern, self.rewriter)

# Run one pattern matching iteration
new_graph_def, changes = optimizer.match_patterns_once(
Expand Down
27 changes: 11 additions & 16 deletions transforms/scalar/algebraic_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,15 @@ class AlgebraicSimplifyPass(PatternRewritePass):
"""

def __init__(self):
    """Register one specific Op() pattern per supported operation.

    Specific patterns let the optimizer use its internal indexing for
    O(1) matching, instead of invoking the rewriter for every node as a
    catch-all Any() pattern would.
    """
    # Operations this pass knows how to simplify.
    supported_ops = (
        "Add", "Sub", "Mul", "Div", "Neg", "LogicalNot", "Abs", "Square", "Sqrt", "Pow",
        "Equal", "NotEqual", "Less", "Greater", "LessEqual", "GreaterEqual",
        "LogicalAnd", "LogicalOr", "Select", "Identity",
    )
    super().__init__(
        [Op(op_name, alias="op") for op_name in supported_ops],
        self._rewrite,
        name="AlgebraicSimplify",
    )

def _rewrite(self, match, optimizer):
node = match.matched_nodes["op"]
Expand Down Expand Up @@ -127,20 +133,9 @@ def _is_const(node_name, value):
# Check if all elements are equal to the target value
return np.all(np.equal(val, value))

# Helper to get shape of a node
# Use improved core.py utility for shape discovery
def _get_shape(node_name):
node = _get_node(node_name)
if node is None:
return None
# Check for shape attribute (Placeholder, etc.)
if "shape" in node.attr:
return [d.size for d in node.attr["shape"].shape.dim]
# Check for Const value shape
if node.op == "Const" and "value" in node.attr:
tensor = node.attr["value"].tensor
if tensor.HasField("tensor_shape"):
return [d.size for d in tensor.tensor_shape.dim]
return None
return optimizer.get_node_shape(node_name)

# Helper to check if a node is definitely scalar
def _is_scalar(node_name):
Expand Down
15 changes: 12 additions & 3 deletions transforms/scalar/constant_fold.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,18 @@ class ConstantFoldPass(PatternRewritePass):
"""

def __init__(self):
    """Register one specific Op() pattern per foldable operation.

    Specific patterns let the optimizer use its internal indexing for
    O(1) matching, instead of invoking the rewriter for every node as a
    catch-all Any() pattern would.
    """
    # Operations eligible for compile-time evaluation when all inputs are Const.
    foldable_ops = (
        "Add", "Mul", "Sub", "Div", "Neg", "Equal", "NotEqual", "Less", "Greater",
        "LessEqual", "GreaterEqual", "LogicalAnd", "LogicalOr", "LogicalNot",
        "BitwiseAnd", "BitwiseOr", "BitwiseXor", "Abs", "Exp", "Expm1", "Log",
        "Log1p", "Sqrt", "Pow", "Rsqrt", "Square", "Sin", "Cos", "Tan", "Asin",
        "Acos", "Atan", "Atan2", "Floor", "Ceil", "Round", "Sign",
        "Reshape", "Transpose", "ConcatV2", "Select", "Cast",
    )
    super().__init__(
        [Op(op_name, alias="op") for op_name in foldable_ops],
        self._rewrite_constant_op,
        name="ConstantFold",
    )

def _is_all_const(self, inputs, optimizer):
"""Check if all inputs are Const nodes.
Expand Down