From 03175ec63c4b075d8b045bf9862b0e72d2563daa Mon Sep 17 00:00:00 2001 From: shangkunwang Date: Thu, 23 Apr 2026 22:12:51 +0000 Subject: [PATCH] refactor: separate xla_module and xla_op events in extract_xprof_time function --- MaxKernel/evaluation/xprof_utils.py | 45 +++++++++++++++++------------ 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/MaxKernel/evaluation/xprof_utils.py b/MaxKernel/evaluation/xprof_utils.py index 18a5646..8bc4274 100644 --- a/MaxKernel/evaluation/xprof_utils.py +++ b/MaxKernel/evaluation/xprof_utils.py @@ -38,8 +38,9 @@ def extract_xprof_time( total_duration_ps = 0 count = 0 - matched_events = [] + xla_module_events = [] + xla_op_events = [] for file_path in xplane_files: try: with open(file_path, "rb") as f: @@ -50,37 +51,43 @@ def extract_xprof_time( if "/device:TPU:0" not in plane.name: continue for line in plane.lines: + if line.name not in ("XLA Modules", "XLA Ops"): + continue for event in line.events: name = "" if event.metadata_id in plane.event_metadata: name = plane.event_metadata[event.metadata_id].name if event_name in name: - total_duration_ps += event.duration_ps - count += 1 - matched_events.append( - { - "file": file_path, - "plane": plane.name, - "name": name, - "duration_ps": event.duration_ps, - } - ) - + if line.name == "XLA Modules": + xla_module_events.append( + { + "file": file_path, + "plane": plane.name, + "name": name, + "duration_ps": event.duration_ps, + } + ) + elif name.startswith("%benchmark_func"): + xla_op_events.append( + { + "file": file_path, + "plane": plane.name, + "name": name, + "duration_ps": event.duration_ps, + } + ) except Exception as e: logging.warning(f"Failed to parse {file_path}: {e}") - # Filter for kernel events (starting with %) - kernel_events = [ev for ev in matched_events if ev["name"].startswith("%")] - - if kernel_events: + if xla_op_events: logging.info("Found kernel events starting with %. Using them.") - target_events = kernel_events + target_events = xla_op_events else: logging.info( - "No kernel events found. Falling back to all matched events (likely wrappers)." + "No xla_op_events events found. Falling back to xla_module_events matched events." ) - target_events = matched_events + target_events = xla_module_events total_duration_ps = sum(ev["duration_ps"] for ev in target_events) count = len(target_events)