diff --git a/core.py b/core.py index 0433354..c95e686 100644 --- a/core.py +++ b/core.py @@ -815,34 +815,31 @@ def _do_match(self, node, optimizer, context): if len(data_inputs) != 2 or len(self.inputs) != 2: return False - # Instead of full copy, just swap temporarily and swap back - original_inputs = list(node.input) - - # Find indices of data inputs - data_indices = [ - i for i, name in enumerate(node.input) if not name.startswith("^") - ] - if len(data_indices) != 2: + # Bolt ⚡: This optimization avoids expensive in-place modification of the + # graph node's inputs and list creation. Instead, we temporarily swap + # the pattern's inputs and match against the original node. + if len(self.inputs) != 2: return False - # Swap them - i1, i2 = data_indices[0], data_indices[1] - node.input[i1], node.input[i2] = original_inputs[i2], original_inputs[i1] - - try: - # Try matching with swapped inputs - saved_nodes = context.matched_nodes.copy() - saved_all = context.all_matched_nodes.copy() - - res = super()._do_match(node, optimizer, context) - if not res: - # Restore context if match failed - context.matched_nodes = saved_nodes - context.all_matched_nodes = saved_all - return res - finally: - # ALWAYS restore original inputs - node.input[i1], node.input[i2] = original_inputs[i1], original_inputs[i2] + # Temporarily swap pattern inputs + self.inputs[0], self.inputs[1] = self.inputs[1], self.inputs[0] + + # Try matching with swapped pattern + # Create a clean context for the swapped match attempt + clean_context = MatchContext() + res = super()._do_match(node, optimizer, clean_context) + + # ALWAYS restore original pattern input order + self.inputs[0], self.inputs[1] = self.inputs[1], self.inputs[0] + + if res: + # If swapped match was successful, merge the results into the main context + context.matched_nodes.update(clean_context.matched_nodes) + context.all_matched_nodes.update(clean_context.all_matched_nodes) + context.control_inputs.update(clean_context.control_inputs) + return True + + return False def CommutativeOp( diff --git a/demos/run_demo.py b/demos/run_demo.py index 2deba89..23f8536 100644 --- a/demos/run_demo.py +++ b/demos/run_demo.py @@ -58,8 +58,8 @@ def main(): output_path = "demos/graph_def_rankmixer_optimized.pb" # Generate the complex graph - # print(f"Generating graph to {input_path}...") - # create_complex_concat_graph(input_path) + print(f"Generating graph to {input_path}...") + create_complex_concat_graph(input_path) # 2. Evaluate original graph print("Evaluating original graph...")