diff --git a/src/evalml/cli.py b/src/evalml/cli.py index 839e7a21..51a9ed45 100644 --- a/src/evalml/cli.py +++ b/src/evalml/cli.py @@ -1,5 +1,9 @@ +import base64 import shlex import subprocess +import urllib.error +import urllib.request +import zlib from pathlib import Path from typing import Any @@ -15,6 +19,56 @@ def run_command(command: list[str]) -> int: return subprocess.run(command).returncode +def _base_snakemake_command( + config: ConfigModel, configfile: Path, cores: int +) -> list[str]: + command = ["snakemake"] + command += config.profile.parsable() + command += ["--configfile", str(configfile)] + command += ["--cores", str(cores)] + return command + + +def _dot_to_svg(dot_content: str) -> bytes: + compressed = zlib.compress(dot_content.encode(), 9) + encoded = base64.urlsafe_b64encode(compressed).decode() + req = urllib.request.Request( + f"https://kroki.io/graphviz/svg/{encoded}", + headers={"User-Agent": "curl/7.68.0"}, + ) + try: + with urllib.request.urlopen(req) as response: + return response.read() + except urllib.error.HTTPError as e: + raise click.ClickException( + f"kroki.io request failed: {e.code} {e.reason}" + ) from e + except urllib.error.URLError as e: + raise click.ClickException(f"kroki.io request failed: {e.reason}") from e + + +def generate_graph( + configfile: Path, + target: str, + graph_type: str, + cores: int, + extra_smk_args: tuple[str, ...] = (), +) -> None: + config = ConfigModel.model_validate(load_yaml(configfile)) + command = _base_snakemake_command(config, configfile, cores) + command += [f"--{graph_type}", "dot", target] + command += list(extra_smk_args) + + result = subprocess.run(command, capture_output=True, text=True) + if result.returncode != 0: + click.echo(result.stderr, err=True) + raise SystemExit(result.returncode) + + output_file = Path(f"{graph_type}.svg") + output_file.write_bytes(_dot_to_svg(result.stdout)) + click.echo(f"Graph saved to {output_file}") + + def load_yaml(path: Path) -> dict[str, Any]: with path.open("r") as f: return yaml.safe_load(f) @@ -47,6 +101,14 @@ def workflow_options(func): is_flag=False, flag_value=f"{command_name}_report.html", )(func) + func = click.option( + "--dag", is_flag=True, help="Generate a DAG and save as dag.svg." + )(func) + func = click.option( + "--rulegraph", + is_flag=True, + help="Generate a rule graph and save as rulegraph.svg.", + )(func) func = click.argument( "extra_smk_args", nargs=-1, @@ -64,14 +126,18 @@ def execute_workflow( dry_run: bool, unlock: bool, report: Path | None, + dag: bool = False, + rulegraph: bool = False, extra_smk_args: tuple[str, ...] = (), ): - config = ConfigModel.model_validate(load_yaml(configfile)) + if dag or rulegraph: + generate_graph( + configfile, target, "dag" if dag else "rulegraph", cores, extra_smk_args + ) + return - command = ["snakemake"] - command += config.profile.parsable() - command += ["--configfile", str(configfile)] - command += ["--cores", str(cores)] + config = ConfigModel.model_validate(load_yaml(configfile)) + command = _base_snakemake_command(config, configfile, cores) if dry_run: command.append("--dry-run") @@ -98,7 +164,9 @@ def cli(): "configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path) ) @workflow_options -def experiment(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args): +def experiment( + configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args +): execute_workflow( configfile, "experiment_all", @@ -107,6 +175,8 @@ def experiment(configfile, cores, verbose, dry_run, unlock, report, extra_smk_ar dry_run, unlock, report, + dag, + rulegraph, extra_smk_args, ) @@ -116,7 +186,9 @@ def experiment(configfile, cores, verbose, dry_run, unlock, report, extra_smk_ar "configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path) ) @workflow_options -def showcase(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args): +def showcase( + configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args +): execute_workflow( configfile, "showcase_all", @@ -125,6 +197,8 @@ def showcase(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args dry_run, unlock, report, + dag, + rulegraph, extra_smk_args, ) @@ -134,7 +208,9 @@ def showcase(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args "configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path) ) @workflow_options -def sandbox(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args): +def sandbox( + configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args +): execute_workflow( configfile, "sandbox_all", @@ -143,6 +219,8 @@ def sandbox(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args) dry_run, unlock, report, + dag, + rulegraph, extra_smk_args, ) @@ -153,7 +231,18 @@ def sandbox(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args) ) @click.argument("target", type=str) @workflow_options -def make(configfile, target, cores, verbose, dry_run, unlock, report, extra_smk_args): +def make( + configfile, + target, + cores, + verbose, + dry_run, + unlock, + report, + dag, + rulegraph, + extra_smk_args, +): execute_workflow( configfile, target, @@ -162,5 +251,7 @@ def make(configfile, target, cores, verbose, dry_run, unlock, report, extra_smk_ dry_run, unlock, report, + dag, + rulegraph, extra_smk_args, )