diff --git a/Fuser/auto_agent.py b/Fuser/auto_agent.py index dad14a6..eacc290 100644 --- a/Fuser/auto_agent.py +++ b/Fuser/auto_agent.py @@ -344,6 +344,7 @@ def __init__( dispatch_jobs: int = 2, allow_fallback: bool = True, target_platform: str | None = None, + ignore_router_config: bool = False, use_router_cache: bool = True, no_cusolver: bool = False, ) -> None: @@ -368,6 +369,7 @@ def __init__( self.dispatch_jobs = dispatch_jobs self.allow_fallback = allow_fallback self.platform_config = get_platform(target_platform) + self.ignore_router_config = ignore_router_config self.use_router_cache = use_router_cache self.no_cusolver = no_cusolver @@ -528,8 +530,8 @@ def solve(self, problem_path: Path) -> RouteResult: # Confidence too low or invalid JSON; resort to heuristic strategy = "fuser" if heuristic_prefers_fuser else "kernelagent" - # Apply optional dynamic config from router - if isinstance(route_cfg, dict): + # Apply optional dynamic config from router (skip if ignore requested) + if isinstance(route_cfg, dict) and not self.ignore_router_config: # KernelAgent tuning self.ka_max_rounds = int(route_cfg.get("ka_max_rounds", self.ka_max_rounds)) self.ka_num_workers = int( @@ -729,6 +731,11 @@ def main(argv: list[str] | None = None) -> int: p.add_argument("--verify", action="store_true") p.add_argument("--dispatch-jobs", type=int, default=2) p.add_argument("--no-fallback", action="store_true") + p.add_argument( + "--ignore-router-config", + action="store_true", + help="Ignore router config. 
Use CLI-provided model/config arguments", + ) p.add_argument( "--no-router-cache", action="store_true", @@ -776,6 +783,7 @@ def main(argv: list[str] | None = None) -> int: dispatch_jobs=args.dispatch_jobs, allow_fallback=(not args.no_fallback), target_platform=args.target_platform, + ignore_router_config=args.ignore_router_config, use_router_cache=(not args.no_router_cache), no_cusolver=args.no_cusolver, ) diff --git a/README.md b/README.md index 26302a5..18792af 100644 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ More knobs live in `triton_kernel_agent/agent.py` and `Fuser/config.py`. ## Component Details -- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. Use `--no-router-cache` ignore the existing cache and caching new routes. +- **AutoRouter (`Fuser/auto_agent.py`)**: parses the problem’s AST, looks for attention blocks, transposed convolutions, control flow, and long op chains. It caches decisions under `.fuse/router_cache.json` and can fall back to the other path if the first attempt fails. Use `--no-router-cache` to ignore the existing cache and skip caching new routes. Use `--ignore-router-config` to ignore router-provided tuning and rely on CLI args. - **Fuser Orchestrator (`Fuser/orchestrator.py`)**: rewrites the PyTorch module into fusable modules, executes them for validation, and packages a tarball of the fused code. Run IDs and directories are managed via `Fuser/paths.py`.