Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions core.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,34 +815,31 @@ def _do_match(self, node, optimizer, context):
if len(data_inputs) != 2 or len(self.inputs) != 2:
return False

# Instead of full copy, just swap temporarily and swap back
original_inputs = list(node.input)

# Find indices of data inputs
data_indices = [
i for i, name in enumerate(node.input) if not name.startswith("^")
]
if len(data_indices) != 2:
# Bolt ⚡: This optimization avoids expensive in-place modification of the
# graph node's inputs and list creation. Instead, we temporarily swap
# the pattern's inputs and match against the original node.
if len(self.inputs) != 2:
return False

# Swap them
i1, i2 = data_indices[0], data_indices[1]
node.input[i1], node.input[i2] = original_inputs[i2], original_inputs[i1]

try:
# Try matching with swapped inputs
saved_nodes = context.matched_nodes.copy()
saved_all = context.all_matched_nodes.copy()

res = super()._do_match(node, optimizer, context)
if not res:
# Restore context if match failed
context.matched_nodes = saved_nodes
context.all_matched_nodes = saved_all
return res
finally:
# ALWAYS restore original inputs
node.input[i1], node.input[i2] = original_inputs[i1], original_inputs[i2]
# Temporarily swap pattern inputs
self.inputs[0], self.inputs[1] = self.inputs[1], self.inputs[0]

# Try matching with swapped pattern
# Create a clean context for the swapped match attempt
clean_context = MatchContext()
res = super()._do_match(node, optimizer, clean_context)

# ALWAYS restore original pattern input order
self.inputs[0], self.inputs[1] = self.inputs[1], self.inputs[0]

if res:
# If swapped match was successful, merge the results into the main context
context.matched_nodes.update(clean_context.matched_nodes)
context.all_matched_nodes.update(clean_context.all_matched_nodes)
context.control_inputs.update(clean_context.control_inputs)
return True

return False


def CommutativeOp(
Expand Down
4 changes: 2 additions & 2 deletions demos/run_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def main():
output_path = "demos/graph_def_rankmixer_optimized.pb"

# Generate the complex graph
# print(f"Generating graph to {input_path}...")
# create_complex_concat_graph(input_path)
print(f"Generating graph to {input_path}...")
create_complex_concat_graph(input_path)

# 2. Evaluate original graph
print("Evaluating original graph...")
Expand Down