Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 100 additions & 9 deletions src/evalml/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import base64
import shlex
import subprocess
import urllib.error
import urllib.request
import zlib
from pathlib import Path
from typing import Any

Expand All @@ -15,6 +19,56 @@ def run_command(command: list[str]) -> int:
return subprocess.run(command).returncode


def _base_snakemake_command(
config: ConfigModel, configfile: Path, cores: int
) -> list[str]:
command = ["snakemake"]
command += config.profile.parsable()
command += ["--configfile", str(configfile)]
command += ["--cores", str(cores)]
return command


def _dot_to_svg(dot_content: str) -> bytes:
compressed = zlib.compress(dot_content.encode(), 9)
encoded = base64.urlsafe_b64encode(compressed).decode()
req = urllib.request.Request(
f"https://kroki.io/graphviz/svg/{encoded}",
headers={"User-Agent": "curl/7.68.0"},
)
try:
with urllib.request.urlopen(req) as response:
return response.read()
except urllib.error.HTTPError as e:
raise click.ClickException(
f"kroki.io request failed: {e.code} {e.reason}"
) from e
except urllib.error.URLError as e:
raise click.ClickException(f"kroki.io request failed: {e.reason}") from e


def generate_graph(
configfile: Path,
target: str,
graph_type: str,
cores: int,
extra_smk_args: tuple[str, ...] = (),
) -> None:
config = ConfigModel.model_validate(load_yaml(configfile))
command = _base_snakemake_command(config, configfile, cores)
command += [f"--{graph_type}", "dot", target]
command += list(extra_smk_args)

result = subprocess.run(command, capture_output=True, text=True)
if result.returncode != 0:
click.echo(result.stderr, err=True)
raise SystemExit(result.returncode)

output_file = Path(f"{graph_type}.svg")
output_file.write_bytes(_dot_to_svg(result.stdout))
click.echo(f"Graph saved to {output_file}")


def load_yaml(path: Path) -> dict[str, Any]:
with path.open("r") as f:
return yaml.safe_load(f)
Expand Down Expand Up @@ -47,6 +101,14 @@ def workflow_options(func):
is_flag=False,
flag_value=f"{command_name}_report.html",
)(func)
func = click.option(
"--dag", is_flag=True, help="Generate a DAG and save as dag.svg."
)(func)
func = click.option(
"--rulegraph",
is_flag=True,
help="Generate a rule graph and save as rulegraph.svg.",
)(func)
func = click.argument(
"extra_smk_args",
nargs=-1,
Expand All @@ -64,14 +126,18 @@ def execute_workflow(
dry_run: bool,
unlock: bool,
report: Path | None,
dag: bool = False,
rulegraph: bool = False,
extra_smk_args: tuple[str, ...] = (),
):
config = ConfigModel.model_validate(load_yaml(configfile))
if dag or rulegraph:
generate_graph(
configfile, target, "dag" if dag else "rulegraph", cores, extra_smk_args
)
return

command = ["snakemake"]
command += config.profile.parsable()
command += ["--configfile", str(configfile)]
command += ["--cores", str(cores)]
config = ConfigModel.model_validate(load_yaml(configfile))
command = _base_snakemake_command(config, configfile, cores)

if dry_run:
command.append("--dry-run")
Expand All @@ -98,7 +164,9 @@ def cli():
"configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path)
)
@workflow_options
def experiment(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args):
def experiment(
configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args
):
execute_workflow(
configfile,
"experiment_all",
Expand All @@ -107,6 +175,8 @@ def experiment(configfile, cores, verbose, dry_run, unlock, report, extra_smk_ar
dry_run,
unlock,
report,
dag,
rulegraph,
extra_smk_args,
)

Expand All @@ -116,7 +186,9 @@ def experiment(configfile, cores, verbose, dry_run, unlock, report, extra_smk_ar
"configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path)
)
@workflow_options
def showcase(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args):
def showcase(
configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args
):
execute_workflow(
configfile,
"showcase_all",
Expand All @@ -125,6 +197,8 @@ def showcase(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args
dry_run,
unlock,
report,
dag,
rulegraph,
extra_smk_args,
)

Expand All @@ -134,7 +208,9 @@ def showcase(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args
"configfile", type=click.Path(exists=True, dir_okay=False, path_type=Path)
)
@workflow_options
def sandbox(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args):
def sandbox(
configfile, cores, verbose, dry_run, unlock, report, dag, rulegraph, extra_smk_args
):
execute_workflow(
configfile,
"sandbox_all",
Expand All @@ -143,6 +219,8 @@ def sandbox(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args)
dry_run,
unlock,
report,
dag,
rulegraph,
extra_smk_args,
)

Expand All @@ -153,7 +231,18 @@ def sandbox(configfile, cores, verbose, dry_run, unlock, report, extra_smk_args)
)
@click.argument("target", type=str)
@workflow_options
def make(configfile, target, cores, verbose, dry_run, unlock, report, extra_smk_args):
def make(
configfile,
target,
cores,
verbose,
dry_run,
unlock,
report,
dag,
rulegraph,
extra_smk_args,
):
execute_workflow(
configfile,
target,
Expand All @@ -162,5 +251,7 @@ def make(configfile, target, cores, verbose, dry_run, unlock, report, extra_smk_
dry_run,
unlock,
report,
dag,
rulegraph,
extra_smk_args,
)
Loading